diff options
| -rw-r--r-- | pyecsca/sca/re/tree.py | 42 | ||||
| -rw-r--r-- | test/sca/test_tree.py | 10 |
2 files changed, 46 insertions, 6 deletions
diff --git a/pyecsca/sca/re/tree.py b/pyecsca/sca/re/tree.py index f5d5b61..1681b82 100644 --- a/pyecsca/sca/re/tree.py +++ b/pyecsca/sca/re/tree.py @@ -44,7 +44,7 @@ Here we grow the trees. """ from math import ceil from copy import deepcopy -from typing import Mapping, Any, Set, List, Tuple, Optional +from typing import Mapping, Any, Set, List, Tuple, Optional, Dict import numpy as np import pandas as pd @@ -117,11 +117,41 @@ class Map: self.codomain = codomain @classmethod - def from_sets(cls, cfgs: Set[Any], mapping: Mapping[Any, Set[Any]]): - cfgs_l = list(cfgs) - cfg_map = pd.DataFrame(list(range(len(cfgs_l))), index=cfgs_l, columns=["vals"]) - inputs_l = list(set().union(*mapping.values())) - data = [[elem in mapping[cfg] for elem in inputs_l] for cfg in cfgs_l] + def from_sets( + cls, cfgs: Set[Any], mapping: Mapping[Any, Set[Any]], deduplicate: bool = False + ): + if deduplicate: + hash2cfg: Dict[int, Set[Any]] = {} + hash2val: Dict[int, Set[Any]] = {} + inputs = set() + for cfg, val in mapping.items(): + inputs.update(val) + # TODO: Note this may cause collisions? + h = hash(tuple(sorted(map(hash, val)))) + if hash2val.setdefault(h, val) != val: + raise ValueError("Collision in dedup!") + hcfgs = hash2cfg.setdefault(h, set()) + hcfgs.add(cfg) + cfgs_l: List[Any] = [] + cfgs_i = [] + cfgs_vals = [] + for i, (h, hcfgs) in enumerate(hash2cfg.items()): + cfgs_l.extend(hcfgs) + cfgs_i.extend([i] * len(hcfgs)) + cfgs_vals.append(hash2val[h]) + cfg_map = pd.DataFrame(cfgs_i, index=cfgs_l, columns=["vals"]) + inputs_l = list(inputs) + data = [[elem in val for elem in inputs_l] for val in cfgs_vals] + else: + cfgs_l = list(cfgs) + cfg_map = pd.DataFrame( + list(range(len(cfgs_l))), index=cfgs_l, columns=["vals"] + ) + inputs = set() + for val in mapping.values(): + inputs.update(val) + inputs_l = list(inputs) + data = [[elem in mapping[cfg] for elem in inputs_l] for cfg in cfgs_l] return Map(pd.DataFrame(data), cfg_map, inputs_l, {True, False}) @classmethod diff --git a/test/sca/test_tree.py b/test/sca/test_tree.py index 5db083c..f3e87cd 100644 --- a/test/sca/test_tree.py +++ b/test/sca/test_tree.py @@ -61,6 +61,13 @@ def test_map_deduplicate(): for i in [1, 2, 3, 4]: assert dmap[cfg, i] == original[cfg, i] assert len(dmap.mapping) < len(original.mapping) + assert dmap.cfgs == original.cfgs + + dedupped = Map.from_sets(cfgs, binary_sets, deduplicate=True) + for cfg in cfgs: + for i in [1, 2, 3, 4]: + assert dedupped[cfg, i] == original[cfg, i] + assert dedupped.cfgs == original.cfgs def test_map_with_callable(secp128r1): @@ -104,12 +111,15 @@ def test_build_tree_dedup(): "g": {4, 2}, } dmap = Map.from_sets(cfgs, binary_sets) + deduplicated = Map.from_sets(cfgs, binary_sets, deduplicate=True) original = deepcopy(dmap) dmap.deduplicate() tree = Tree.build(cfgs, original) dedup = Tree.build(cfgs, dmap) + dedup_other = Tree.build(cfgs, deduplicated) assert tree.describe() == dedup.describe() + assert tree.describe() == dedup_other.describe() def test_build_tree_reorder(): |
