aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--pyecsca/sca/re/tree.py17
-rw-r--r--test/sca/test_tree.py17
2 files changed, 25 insertions, 9 deletions
diff --git a/pyecsca/sca/re/tree.py b/pyecsca/sca/re/tree.py
index 2955631..8e70ab0 100644
--- a/pyecsca/sca/re/tree.py
+++ b/pyecsca/sca/re/tree.py
@@ -158,7 +158,9 @@ class Map:
indices.append(thing.index)
return thing.iloc[0]
- self.mapping = self.mapping.groupby(self.mapping.columns.tolist(), as_index=False, dropna=False).agg(agg)
+ self.mapping = self.mapping.groupby(
+ self.mapping.columns.tolist(), as_index=False, dropna=False
+ ).agg(agg)
new_cfg_map = self.cfg_map.copy()
for i, index in enumerate(indices):
new_cfg_map.loc[self.cfg_map["vals"].isin(index), "vals"] = i
@@ -189,9 +191,13 @@ class Node(NodeMixin):
"""A node in a distinguishing tree."""
cfgs: Set[Any]
+ """Set of configs associated with this node."""
+ response: Optional[Any]
+ """The response to the *previous* oracle call that resulted in this node."""
dmap_index: Optional[int]
+ """The dmap index to be used for the oracle call for this node."""
dmap_input: Optional[Any]
- response: Optional[Any]
+ """The input for the oracle call for this node (is from dmap at dmap_index in the Tree)."""
def __init__(
self,
@@ -217,7 +223,9 @@ class Tree:
"""A distinguishing tree."""
maps: List[Map]
+ """A list of dmaps. Nodes index into this when choosing which oracle to use."""
root: Node
+ """A root of the tree."""
def __init__(self, root: Node, *maps: Map):
self.maps = list(maps)
@@ -225,17 +233,21 @@ class Tree:
@property
def leaves(self) -> Tuple[Node]:
+ """Get the leaves of the tree as a tuple."""
return self.root.leaves
@property
def height(self) -> int:
+ """Get the height of the tree (distance from the root to the deepest leaf)."""
return self.root.height
@property
def size(self) -> int:
+ """Get the size of the tree (number of nodes)."""
return self.root.size
def render(self) -> str:
+ """Render the tree."""
style = AbstractStyle("\u2502 ", "\u251c\u2500\u2500", "\u2514\u2500\u2500")
def _str(n: Node):
@@ -248,6 +260,7 @@ class Tree:
return RenderTree(self.root, style=style).by_attr(_str)
def describe(self) -> str:
+ """Describe some important properties of the tree."""
leaf_sizes = [len(leaf.cfgs) for leaf in self.leaves]
leafs: List[int] = sum(([size] * size for size in leaf_sizes), [])
return "\n".join(
diff --git a/test/sca/test_tree.py b/test/sca/test_tree.py
index 89a9f61..234732d 100644
--- a/test/sca/test_tree.py
+++ b/test/sca/test_tree.py
@@ -1,5 +1,4 @@
import random
-import time
from copy import deepcopy
from pyecsca.sca.re.tree import Tree, Map
@@ -31,12 +30,16 @@ def test_map_merge():
binary_sets = {"c": {1, 2}, "d": {2, 4, 3}}
dmap2 = Map.from_sets(cfgs, binary_sets)
assert len(dmap2.mapping) == 2
- dmap1.merge(dmap2)
- assert len(dmap1.mapping) == 4
- assert len(dmap1.cfg_map) == 4
- assert len(dmap1.codomain) == 2
- assert not dmap1["c", 3]
- assert dmap1["a", 1]
+ merged = deepcopy(dmap1)
+ merged.merge(dmap2)
+ assert len(merged.mapping) == 4
+ assert len(merged.cfg_map) == 4
+ assert len(merged.codomain) == 2
+ for i in [1, 2, 3, 4]:
+ for cfg in "ab":
+ assert merged[cfg, i] == dmap1[cfg, i]
+ for cfg in "cd":
+ assert merged[cfg, i] == dmap2[cfg, i]
def test_map_deduplicate():