aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--pyecsca/sca/re/tree.py29
-rw-r--r--pyproject.toml3
-rw-r--r--test/sca/test_tree.py2
3 files changed, 26 insertions, 8 deletions
diff --git a/pyecsca/sca/re/tree.py b/pyecsca/sca/re/tree.py
index 43f1901..a45faaa 100644
--- a/pyecsca/sca/re/tree.py
+++ b/pyecsca/sca/re/tree.py
@@ -45,11 +45,12 @@ Here we grow the trees.
from math import ceil, log2
from copy import deepcopy
-from typing import Mapping, Any, Set, List, Tuple, Optional, Dict, Union
+from typing import Mapping, Any, Callable, Set, List, Tuple, Optional, Dict, Union
import numpy as np
import pandas as pd
from public import public
+from tqdm.auto import tqdm
from anytree import RenderTree, NodeMixin, AbstractStyle, PreOrderIter
from pyecsca.misc.utils import log
@@ -556,16 +557,26 @@ class Tree:
return tree
@classmethod
- def build(cls, cfgs: Set[Any], *maps: Map, split: Union[str, SplitCriterion] = "largest") -> "Tree":
+ def build(cls, cfgs: Set[Any], *maps: Map, split: Union[str, SplitCriterion] = "largest", progress: bool = False) -> "Tree":
"""
Build a tree.
:param cfgs: The set of configs to build the tree for.
:param maps: The distinguishing maps to use.
:param split: The split criterion to use. Can be one of "degree", "largest", "average".
+ :param progress: Whether to report progress via a progress bar.
:return: The tree.
"""
- return cls(_build_tree(cfgs, dict(enumerate(maps)), split=split), *maps)
+ leaf_callback: Optional[Callable[[int], None]]
+ if progress:
+ bar = tqdm(total=len(cfgs), desc="Building tree.")
+ def leaf_callback(n_cfgs): return bar.update(n_cfgs)
+ else:
+ leaf_callback = None
+ res = cls(_build_tree(cfgs, dict(enumerate(maps)), split=split, leaf_callback=leaf_callback), *maps)
+ if progress:
+ bar.close()
+ return res
def _build_tree(
@@ -576,17 +587,19 @@ def _build_tree(
index: Optional[int] = None,
breadth: Optional[int] = None,
split: Union[str, SplitCriterion] = "largest",
+ leaf_callback: Optional[Callable[[int], None]] = None
) -> Node:
- pad = " " * depth
+ pad = " " * depth + (f"[{depth}, {index}/{breadth}] " if index is not None and breadth is not None else "")
# If there is only one remaining cfg, we do not need to continue and just return (base case 1).
# Note that n_cfgs will never be 0 here, as the base case 2 returns if the cfgs cannot
# be split into two sets (one would be empty).
n_cfgs = len(cfgs)
- ident = f"[{index}/{breadth}] " if index is not None and breadth is not None else ""
- log(pad + f"{ident}Splitting {n_cfgs}.")
+ log(pad + f"Splitting {n_cfgs}.")
cfgset = set(cfgs)
if n_cfgs == 1:
log(pad + "Trivial.")
+ if leaf_callback is not None:
+ leaf_callback(n_cfgs)
return Node(cfgset, response=response)
# Choose the split criterion
@@ -627,6 +640,8 @@ def _build_tree(
# We found nothing distinguishing the configs, so return them all (base case 2).
if best_column is None or best_dmap is None:
log(pad + "Nothing could split.")
+ if leaf_callback is not None:
+ leaf_callback(n_cfgs)
return Node(cfgset, response=response)
best_distinguishing_element = best_dmap.domain[best_column]
@@ -636,6 +651,8 @@ def _build_tree(
# We found nothing distinguishing the configs, so return them all (base case 2).
if groups.ngroups == 1:
log(pad + "Trivial split.")
+ if leaf_callback is not None:
+ leaf_callback(n_cfgs)
return Node(cfgset, response=response)
# Create our node
diff --git a/pyproject.toml b/pyproject.toml
index 83674c7..da0a9f1 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -55,7 +55,8 @@
"networkx",
"importlib-resources",
"anytree",
- "loky"
+ "loky",
+ "tqdm"
]
[project.urls]
diff --git a/test/sca/test_tree.py b/test/sca/test_tree.py
index a3afaf9..715c3e0 100644
--- a/test/sca/test_tree.py
+++ b/test/sca/test_tree.py
@@ -95,7 +95,7 @@ def test_build_tree(split):
codomain2 = {0, 1, 2, 3}
mapping2 = pd.DataFrame([(1, 0, 0), (2, 0, 0), (3, 0, 0)])
dmap2 = Map(mapping2, cfg_map, inputs2, codomain2)
- tree = Tree.build(set(cfgs), dmap1, dmap2, split=split)
+ tree = Tree.build(set(cfgs), dmap1, dmap2, split=split, progress=True)
tree.render()
tree.render_basic()
tree.describe()