aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--pyecsca/sca/re/tree.py42
-rw-r--r--test/sca/test_tree.py10
2 files changed, 46 insertions, 6 deletions
diff --git a/pyecsca/sca/re/tree.py b/pyecsca/sca/re/tree.py
index f5d5b61..1681b82 100644
--- a/pyecsca/sca/re/tree.py
+++ b/pyecsca/sca/re/tree.py
@@ -44,7 +44,7 @@ Here we grow the trees.
"""
from math import ceil
from copy import deepcopy
-from typing import Mapping, Any, Set, List, Tuple, Optional
+from typing import Mapping, Any, Set, List, Tuple, Optional, Dict
import numpy as np
import pandas as pd
@@ -117,11 +117,41 @@ class Map:
self.codomain = codomain
@classmethod
- def from_sets(cls, cfgs: Set[Any], mapping: Mapping[Any, Set[Any]]):
- cfgs_l = list(cfgs)
- cfg_map = pd.DataFrame(list(range(len(cfgs_l))), index=cfgs_l, columns=["vals"])
- inputs_l = list(set().union(*mapping.values()))
- data = [[elem in mapping[cfg] for elem in inputs_l] for cfg in cfgs_l]
+ def from_sets(
+ cls, cfgs: Set[Any], mapping: Mapping[Any, Set[Any]], deduplicate: bool = False
+ ):
+ if deduplicate:
+ hash2cfg: Dict[int, Set[Any]] = {}
+ hash2val: Dict[int, Set[Any]] = {}
+ inputs = set()
+ for cfg, val in mapping.items():
+ inputs.update(val)
+ # TODO: Note this may cause collisions?
+ h = hash(tuple(sorted(map(hash, val))))
+ if hash2val.setdefault(h, val) != val:
+ raise ValueError("Collision in dedup!")
+ hcfgs = hash2cfg.setdefault(h, set())
+ hcfgs.add(cfg)
+ cfgs_l: List[Any] = []
+ cfgs_i = []
+ cfgs_vals = []
+ for i, (h, hcfgs) in enumerate(hash2cfg.items()):
+ cfgs_l.extend(hcfgs)
+ cfgs_i.extend([i] * len(hcfgs))
+ cfgs_vals.append(hash2val[h])
+ cfg_map = pd.DataFrame(cfgs_i, index=cfgs_l, columns=["vals"])
+ inputs_l = list(inputs)
+ data = [[elem in val for elem in inputs_l] for val in cfgs_vals]
+ else:
+ cfgs_l = list(cfgs)
+ cfg_map = pd.DataFrame(
+ list(range(len(cfgs_l))), index=cfgs_l, columns=["vals"]
+ )
+ inputs = set()
+ for val in mapping.values():
+ inputs.update(val)
+ inputs_l = list(inputs)
+ data = [[elem in mapping[cfg] for elem in inputs_l] for cfg in cfgs_l]
return Map(pd.DataFrame(data), cfg_map, inputs_l, {True, False})
@classmethod
diff --git a/test/sca/test_tree.py b/test/sca/test_tree.py
index 5db083c..f3e87cd 100644
--- a/test/sca/test_tree.py
+++ b/test/sca/test_tree.py
@@ -61,6 +61,13 @@ def test_map_deduplicate():
for i in [1, 2, 3, 4]:
assert dmap[cfg, i] == original[cfg, i]
assert len(dmap.mapping) < len(original.mapping)
+ assert dmap.cfgs == original.cfgs
+
+ dedupped = Map.from_sets(cfgs, binary_sets, deduplicate=True)
+ for cfg in cfgs:
+ for i in [1, 2, 3, 4]:
+ assert dedupped[cfg, i] == original[cfg, i]
+ assert dedupped.cfgs == original.cfgs
def test_map_with_callable(secp128r1):
@@ -104,12 +111,15 @@ def test_build_tree_dedup():
"g": {4, 2},
}
dmap = Map.from_sets(cfgs, binary_sets)
+ deduplicated = Map.from_sets(cfgs, binary_sets, deduplicate=True)
original = deepcopy(dmap)
dmap.deduplicate()
tree = Tree.build(cfgs, original)
dedup = Tree.build(cfgs, dmap)
+ dedup_other = Tree.build(cfgs, deduplicated)
assert tree.describe() == dedup.describe()
+ assert tree.describe() == dedup_other.describe()
def test_build_tree_reorder():