aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorJ08nY2024-01-18 14:12:12 +0100
committerJ08nY2024-01-18 14:12:12 +0100
commite586e5604a953ab53fefc7efea804e9fb43b28b3 (patch)
tree915da87566fdc1f9c121bcca130d153d242dcb2e
parentd2b52349753c1ed940abecbc6fe3f561fc2fd8fe (diff)
downloadpyecsca-e586e5604a953ab53fefc7efea804e9fb43b28b3.tar.gz
pyecsca-e586e5604a953ab53fefc7efea804e9fb43b28b3.tar.zst
pyecsca-e586e5604a953ab53fefc7efea804e9fb43b28b3.zip
Make tree building more general.
-rw-r--r--pyecsca/sca/re/rpa.py4
-rw-r--r--pyecsca/sca/re/tree.py101
-rw-r--r--test/sca/test_rpa.py2
3 files changed, 70 insertions, 37 deletions
diff --git a/pyecsca/sca/re/rpa.py b/pyecsca/sca/re/rpa.py
index 5bbd145..7bd796e 100644
--- a/pyecsca/sca/re/rpa.py
+++ b/pyecsca/sca/re/rpa.py
@@ -263,7 +263,7 @@ def rpa_distinguish(
used |= multiply_multiples
mults_to_multiples[mult] = used
- tree = build_distinguishing_tree(mults_to_multiples)
+ tree = build_distinguishing_tree(set(mults), mults_to_multiples)
log("Built distinguishing tree.")
log(
RenderTree(tree).by_attr(
@@ -277,7 +277,7 @@ def rpa_distinguish(
continue
current_node = tree
while current_node.children:
- best_distinguishing_multiple = current_node.name
+ _, best_distinguishing_multiple = current_node.name
P0_inverse = rpa_input_point(best_distinguishing_multiple, P0, params)
responses = []
for _ in range(majority):
diff --git a/pyecsca/sca/re/tree.py b/pyecsca/sca/re/tree.py
index d3433cc..306893c 100644
--- a/pyecsca/sca/re/tree.py
+++ b/pyecsca/sca/re/tree.py
@@ -1,4 +1,5 @@
-from typing import Mapping, Any, Set
+from copy import deepcopy
+from typing import Mapping, Any, Set, List
from collections import Counter
from public import public
from anytree import Node
@@ -6,58 +7,90 @@ from anytree import Node
@public
def build_distinguishing_tree(
- cfg2resp: Mapping[Any, Set[Any]], **kwargs
+ cfgs: Set[Any], *cfg2resp_maps: Mapping[Any, Set[Any]], **kwargs
) -> Node:
"""
- Build a distinguishing tree for a given mapping of configs to True oracle responses.
+ Build a distinguishing tree for given mappings of configs to True oracle responses.
- :param cfg2resp:
- :param kwargs:
- :return:
+ :param cfgs: The configurations to distinguish.
+ :param cfg2resp_maps: The mappings of configs to sets of inputs that give True oracle responses.
+ :param kwargs: Additional keyword arguments that will be passed to the node.
+ :return: A distinguishing tree.
"""
- n_cfgs = len(cfg2resp)
+ n_cfgs = len(cfgs)
# If there is only one remaining cfg, we do not need to continue and just return (base case 1).
# Note that n_cfgs will never be 0 here, as the base case 2 returns if the cfgs cannot be split into two sets (one would be empty).
if n_cfgs == 1:
- return Node(None, cfgs=list(cfg2resp.keys()), **kwargs)
+ return Node(None, cfgs=set(cfgs), **kwargs)
- counts: Counter = Counter()
- for elements in cfg2resp.values(): # Elements of the set need to be hashable
- counts.update(elements)
+ # Go over the maps and have a counter for each one
+ counters: List[Counter] = [Counter() for cfg2resp in cfg2resp_maps]
+ for counter, cfg2resp in zip(counters, cfg2resp_maps):
+ for cfg in cfgs:
+ elements = cfg2resp[cfg]
+ counter.update(elements)
nhalf = n_cfgs / 2
+ best_distinguishing_map = None
best_distinguishing_element = None
best_count = None
best_nhalf_distance = None
- for multiple, count in counts.items():
- if (
- best_distinguishing_element is None
- or abs(count - nhalf) < best_nhalf_distance
- ):
- best_distinguishing_element = multiple
- best_count = count
- best_nhalf_distance = abs(count - nhalf)
+ for i, (counter, cfg2resp) in enumerate(zip(counters, cfg2resp_maps)):
+ for multiple, count in counter.items():
+ if (
+ best_distinguishing_element is None
+ or abs(count - nhalf) < best_nhalf_distance
+ ):
+ best_distinguishing_map = i
+ best_distinguishing_element = multiple
+ best_count = count
+ best_nhalf_distance = abs(count - nhalf)
# We found nothing distinguishing the configs, so return them all (base case 2).
- if best_count in (0, n_cfgs, None):
- return Node(None, cfgs=list(cfg2resp.keys()), **kwargs)
+ if best_count in (0, n_cfgs, None) or best_distinguishing_map is None:
+ return Node(None, cfgs=set(cfgs), **kwargs)
result = Node(
- best_distinguishing_element, cfgs=list(cfg2resp.keys()), **kwargs
+ (best_distinguishing_map, best_distinguishing_element), cfgs=set(cfgs), **kwargs
)
# Now go deeper and split based on the best-distinguishing element.
- true_cfgs = {
- mult: elements
- for mult, elements in cfg2resp.items()
- if best_distinguishing_element in elements
- }
- true_child = build_distinguishing_tree(true_cfgs, oracle_response=True)
+ true_cfgs = set(
+ cfg
+ for cfg in cfgs
+ if best_distinguishing_element in cfg2resp_maps[best_distinguishing_map][cfg]
+ )
+ true_restricted_cfg2resps = [
+ {cfg: cfg2resp[cfg] for cfg in true_cfgs} for cfg2resp in cfg2resp_maps
+ ]
+ true_child = build_distinguishing_tree(
+ true_cfgs, *true_restricted_cfg2resps, oracle_response=True
+ )
true_child.parent = result
- false_cfgs = {
- mult: elements
- for mult, elements in cfg2resp.items()
- if best_distinguishing_element not in elements
- }
- false_child = build_distinguishing_tree(false_cfgs, oracle_response=False)
+
+ false_cfgs = cfgs.difference(true_cfgs)
+ false_restricted_cfg2resps = [
+ {cfg: cfg2resp[cfg] for cfg in false_cfgs} for cfg2resp in cfg2resp_maps
+ ]
+ false_child = build_distinguishing_tree(
+ false_cfgs, *false_restricted_cfg2resps, oracle_response=False
+ )
false_child.parent = result
return result
+
+
+def expand_tree(tree: Node, cfg2resp: Mapping[Any, Set[Any]]) -> Node:
+ """
+ Attempt to expand a given distinguishing tree with a new mapping of configs to True oracle responses.
+
+ :param tree: The tree to expand (will be copied).
+ :param cfg2resp: The new map.
+ :return: The expanded tree.
+ """
+ tree = deepcopy(tree)
+ for leaf in tree.leaves:
+ expanded = build_distinguishing_tree(leaf.cfgs, cfg2resp)
+ # If we were able to split the leaf further, then replace it with the found tree.
+ if not expanded.is_leaf:
+ leaf.name = expanded.name
+ leaf.children = expanded.children
+ return tree
diff --git a/test/sca/test_rpa.py b/test/sca/test_rpa.py
index 39c281d..314bc4a 100644
--- a/test/sca/test_rpa.py
+++ b/test/sca/test_rpa.py
@@ -157,4 +157,4 @@ def test_distinguish(secp128r1, add, dbl, neg):
with redirect_stdout(io.StringIO()):
result = rpa_distinguish(secp128r1, multipliers, simulated_oracle)
assert 1 == len(result)
- assert real_mult == result[0]
+ assert real_mult == result.pop()