diff options
| -rw-r--r-- | pyecsca/sca/re/tree.py | 17 | ||||
| -rw-r--r-- | test/sca/test_tree.py | 17 |
2 files changed, 25 insertions, 9 deletions
diff --git a/pyecsca/sca/re/tree.py b/pyecsca/sca/re/tree.py index 2955631..8e70ab0 100644 --- a/pyecsca/sca/re/tree.py +++ b/pyecsca/sca/re/tree.py @@ -158,7 +158,9 @@ class Map: indices.append(thing.index) return thing.iloc[0] - self.mapping = self.mapping.groupby(self.mapping.columns.tolist(), as_index=False, dropna=False).agg(agg) + self.mapping = self.mapping.groupby( + self.mapping.columns.tolist(), as_index=False, dropna=False + ).agg(agg) new_cfg_map = self.cfg_map.copy() for i, index in enumerate(indices): new_cfg_map.loc[self.cfg_map["vals"].isin(index), "vals"] = i @@ -189,9 +191,13 @@ class Node(NodeMixin): """A node in a distinguishing tree.""" cfgs: Set[Any] + """Set of configs associated with this node.""" + response: Optional[Any] + """The response to the *previous* oracle call that resulted in this node.""" dmap_index: Optional[int] + """The dmap index to be used for the oracle call for this node.""" dmap_input: Optional[Any] - response: Optional[Any] + """The input for the oracle call for this node (is from dmap at dmap_index in the Tree).""" def __init__( self, @@ -217,7 +223,9 @@ class Tree: """A distinguishing tree.""" maps: List[Map] + """A list of dmaps. Nodes index into this when choosing which oracle to use.""" root: Node + """A root of the tree.""" def __init__(self, root: Node, *maps: Map): self.maps = list(maps) @@ -225,17 +233,21 @@ class Tree: @property def leaves(self) -> Tuple[Node]: + """Get the leaves of the tree as a tuple.""" return self.root.leaves @property def height(self) -> int: + """Get the height of the tree (distance from the root to the deepest leaf).""" return self.root.height @property def size(self) -> int: + """Get the size of the tree (number of nodes).""" return self.root.size def render(self) -> str: + """Render the tree.""" style = AbstractStyle("\u2502 ", "\u251c\u2500\u2500", "\u2514\u2500\u2500") def _str(n: Node): @@ -248,6 +260,7 @@ class Tree: return RenderTree(self.root, style=style).by_attr(_str) def describe(self) -> str: + """Describe some important properties of the tree.""" leaf_sizes = [len(leaf.cfgs) for leaf in self.leaves] leafs: List[int] = sum(([size] * size for size in leaf_sizes), []) return "\n".join( diff --git a/test/sca/test_tree.py b/test/sca/test_tree.py index 89a9f61..234732d 100644 --- a/test/sca/test_tree.py +++ b/test/sca/test_tree.py @@ -1,5 +1,4 @@ import random -import time from copy import deepcopy from pyecsca.sca.re.tree import Tree, Map @@ -31,12 +30,16 @@ def test_map_merge(): binary_sets = {"c": {1, 2}, "d": {2, 4, 3}} dmap2 = Map.from_sets(cfgs, binary_sets) assert len(dmap2.mapping) == 2 - dmap1.merge(dmap2) - assert len(dmap1.mapping) == 4 - assert len(dmap1.cfg_map) == 4 - assert len(dmap1.codomain) == 2 - assert not dmap1["c", 3] - assert dmap1["a", 1] + merged = deepcopy(dmap1) + merged.merge(dmap2) + assert len(merged.mapping) == 4 + assert len(merged.cfg_map) == 4 + assert len(merged.codomain) == 2 + for i in [1, 2, 3, 4]: + for cfg in "ab": + assert merged[cfg, i] == dmap1[cfg, i] + for cfg in "cd": + assert merged[cfg, i] == dmap2[cfg, i] def test_map_deduplicate(): |
