aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorJ08nY2024-06-04 16:18:20 +0200
committerJ08nY2024-06-04 16:18:20 +0200
commit4eadcd6ad1e4cadcb8bb0b6da8d9c0b62f2a09f0 (patch)
tree8bcb8c253f6b0560e077c1731fda86c0d602eae7
parent46893603bc9ea7b238f160437e4863564bca2f70 (diff)
downloadpyecsca-4eadcd6ad1e4cadcb8bb0b6da8d9c0b62f2a09f0.tar.gz
pyecsca-4eadcd6ad1e4cadcb8bb0b6da8d9c0b62f2a09f0.tar.zst
pyecsca-4eadcd6ad1e4cadcb8bb0b6da8d9c0b62f2a09f0.zip
Improve execution tree API.
-rw-r--r--pyecsca/ec/context.py196
-rw-r--r--pyecsca/ec/formula/metrics.py8
-rw-r--r--pyecsca/sca/attack/CPA.py2
-rw-r--r--pyecsca/sca/attack/DPA.py2
-rw-r--r--pyecsca/sca/target/leakage.py2
-rw-r--r--test/ec/test_context.py72
6 files changed, 135 insertions, 147 deletions
diff --git a/pyecsca/ec/context.py b/pyecsca/ec/context.py
index 5b30460..97a64ca 100644
--- a/pyecsca/ec/context.py
+++ b/pyecsca/ec/context.py
@@ -11,11 +11,11 @@ A :py:class:`PathContext` works like a :py:class:`DefaultContext` that only trac
in the tree.
"""
from abc import abstractmethod, ABC
-from collections import OrderedDict
from copy import deepcopy
from typing import List, Optional, ContextManager, Any, Tuple, Sequence, Callable
from public import public
+from anytree import RenderTree, NodeMixin, AbstractStyle, PostOrderIter
@public
@@ -47,6 +47,9 @@ class Action:
current.exit_action(self)
self.inside = False
+ def __repr__(self):
+ return "Action()"
+
@public
class ResultAction(Action):
@@ -79,125 +82,99 @@ class ResultAction(Action):
def __exit__(self, exc_type, exc_val, exc_tb):
if (
- not self._has_result
- and exc_type is None
- and exc_val is None
- and exc_tb is None
+ not self._has_result
+ and exc_type is None
+ and exc_val is None
+ and exc_tb is None
):
raise RuntimeError("Result unset on action exit")
super().__exit__(exc_type, exc_val, exc_tb)
+ def __repr__(self):
+ return f"ResultAction(result={self._result!r})"
+
@public
-class Tree(OrderedDict):
- """
- A recursively-implemented tree.
+class Node(NodeMixin):
+ """A node in an execution tree."""
- >>> tree = Tree()
- >>> tree["a"] = Tree()
- >>> tree["a"]["1"] = Tree()
- >>> tree["a"]["2"] = Tree()
- >>> tree # doctest: +NORMALIZE_WHITESPACE
- a
- 1
- 2
- <BLANKLINE>
- """
+ action: Action
- def get_by_key(self, path: List) -> Any:
+ def __init__(self, action: Action, parent=None, children=None):
+ self.action = action
+ self.parent = parent
+ if children:
+ self.children = children
+
+ def get_by_key(self, path: List[Action]) -> "Node":
"""
- Get the value in the tree at a position given by the path.
+ Get a Node from the tree by a path of :py:class:`Action` s.
- >>> one = Tree()
- >>> tree = Tree()
- >>> tree["a"] = Tree()
- >>> tree["a"]["1"] = Tree()
- >>> tree["a"]["2"] = one
- >>> tree.get_by_key(["a", "2"]) == one
+ >>> tree = Node(Action())
+ >>> a_a = Action()
+ >>> a = Node(a_a, parent=tree)
+ >>> one_a = Action()
+ >>> one = Node(one_a, parent=a)
+ >>> other_a = Action()
+ >>> other = Node(other_a, parent=a)
+ >>> tree.get_by_key([]) == tree
+ True
+ >>> tree.get_by_key([a_a]) == a
+ True
+ >>> tree.get_by_key(([a_a, one_a])) == one
True
- :param path: The path to get.
- :return: The value in the tree.
+ :param path: The path of actions to walk.
+ :return: The node.
"""
if len(path) == 0:
return self
- value = self[path[0]]
- if len(path) == 1:
- return value
- elif isinstance(value, Tree):
- return value.get_by_key(path[1:])
- else:
- raise ValueError
+ for child in self.children:
+ if path[0] == child.action:
+ return child.get_by_key(path[1:])
+ raise ValueError("Path not found.")
- def get_by_index(self, path: List[int]) -> Tuple[Any, Any]:
+ def get_by_index(self, path: List[int]) -> "Node":
"""
- Get the key and value in the tree at a position given by the path of indices.
-
- The nodes inside a level of a tree are ordered by insertion order.
+ Get a Node from the tree by a path of indices.
- >>> one = Tree()
- >>> tree = Tree()
- >>> tree["a"] = Tree()
- >>> tree["a"]["1"] = Tree()
- >>> tree["a"]["2"] = one
- >>> key, value = tree.get_by_index([0, 1])
- >>> key
- '2'
- >>> value == one
+ >>> tree = Node(Action())
+ >>> a_a = Action()
+ >>> a = Node(a_a, parent=tree)
+ >>> one_a = Action()
+ >>> one = Node(one_a, parent=a)
+ >>> other_a = Action()
+ >>> other = Node(other_a, parent=a)
+ >>> tree.get_by_index([]) == tree
+ True
+ >>> tree.get_by_index([0]) == a
+ True
+ >>> tree.get_by_index(([0, 0])) == one
True
- :param path: The path to get.
- :return: The key and value.
+ :param path: The path of indices.
+ :return: The node.
"""
if len(path) == 0:
- raise ValueError
- key = list(self.keys())[path[0]]
- value = self[key]
- if len(path) == 1:
- return key, value
- elif isinstance(value, Tree):
- return value.get_by_index(path[1:])
- else:
- raise ValueError
-
- def repr(self, depth: int = 0) -> str:
- """
- Construct a textual representation of the tree. Useful for visualization and debugging.
+ return self
+ return self.children[path[0]].get_by_index(path[1:])
- :param depth:
- :return: The resulting textual representation.
+ def walk(self, callback: Callable[[Action], None]):
"""
- result = ""
- for key, value in self.items():
- if isinstance(value, Tree):
- result += "\t" * depth + str(key) + "\n"
- result += value.repr(depth + 1)
- else:
- result += "\t" * depth + str(key) + ":" + str(value) + "\n"
- return result
+ Walk the tree in post-order (as it was executed) and apply :paramref:`callback`.
- def walk(self, callback: Callable[[Any], None]) -> None:
+ :param callback: The callback to apply to the actions in the nodes.
"""
- Walk the tree, depth-first, with the callback.
+ for node in PostOrderIter(self):
+ callback(node.action)
- >>> tree = Tree()
- >>> tree["a"] = Tree()
- >>> tree["a"]["1"] = Tree()
- >>> tree["a"]["2"] = Tree()
- >>> tree.walk(lambda key: print(key))
- a
- 1
- 2
+ def render(self) -> str:
+ """Render the tree."""
+ style = AbstractStyle("\u2502 ", "\u251c\u2500\u2500", "\u2514\u2500\u2500")
+ return RenderTree(self, style=style).by_attr(lambda node: node.action)
- :param callback: The callback to call for all values in the tree.
- """
- for key, val in self.items():
- callback(key)
- if isinstance(val, Tree):
- val.walk(callback)
-
- def __repr__(self):
- return self.repr()
+ def __str__(self):
+ return self.render()
@public
@@ -243,36 +220,37 @@ class DefaultContext(Context):
... r = other_action.exit("some result")
... with Action() as yet_another:
... pass
- >>> ctx.actions # doctest: +NORMALIZE_WHITESPACE, +ELLIPSIS
- <...Action ...
- <...ResultAction ...
- <...Action ...
- <BLANKLINE>
- >>> root, subtree = ctx.actions.get_by_index([0])
- >>> for action in subtree: # doctest: +ELLIPSIS
- ... print(action)
- <...ResultAction ...
- <...Action ...
+ >>> print(ctx.actions) # doctest: +NORMALIZE_WHITESPACE, +ELLIPSIS
+ Action()
+ ├──ResultAction(result='some result')
+ └──Action()
+ >>> for other in ctx.actions.children: # doctest: +ELLIPSIS
+ ... print(other.action)
+ ResultAction(result='some result')
+ Action()
"""
- actions: Tree
+ actions: Optional[Node]
current: List[Action]
def __init__(self):
- self.actions = Tree()
+ self.actions = None
self.current = []
def enter_action(self, action: Action) -> None:
- self.actions.get_by_key(self.current)[action] = Tree()
+ if self.actions is None:
+ self.actions = Node(action)
+ else:
+ Node(action, parent=self.actions.get_by_key(self.current[1:]))
self.current.append(action)
def exit_action(self, action: Action) -> None:
if len(self.current) < 1 or self.current[-1] != action:
- raise ValueError
+ raise ValueError("Cannot exit, not in an action.")
self.current.pop()
def __repr__(self):
- return f"{self.__class__.__name__}({self.actions!r}, current={self.current!r})"
+ return f"{self.__class__.__name__}(actions={self.actions.size if self.actions else 0}, current={self.current!r})"
@public
@@ -282,7 +260,7 @@ class PathContext(Context):
path: List[int]
current: List[int]
current_depth: int
- value: Any
+ value: Optional[Action]
def __init__(self, path: Sequence[int]):
"""
@@ -344,7 +322,7 @@ def local(ctx: Optional[Context] = None) -> ContextManager:
>>> with local(DefaultContext()) as ctx:
... with Action() as action:
... pass
- >>> list(ctx.actions)[0] == action
+ >>> ctx.actions.action == action
True
:param ctx: If none, current context is copied.
diff --git a/pyecsca/ec/formula/metrics.py b/pyecsca/ec/formula/metrics.py
index f8dbb73..063a62d 100644
--- a/pyecsca/ec/formula/metrics.py
+++ b/pyecsca/ec/formula/metrics.py
@@ -103,15 +103,15 @@ def formula_similarity_fuzz(
inputs = (P, Q, R)[: one.num_inputs]
with local(DefaultContext()) as ctx:
res_one = one(curve.prime, *inputs, **curve.parameters)
- action_one = ctx.actions.get_by_index([0])
+ action_one = ctx.actions.action
ivs_one = set(
- map(attrgetter("value"), sum(action_one[0].intermediates.values(), []))
+ map(attrgetter("value"), sum(action_one.intermediates.values(), []))
)
with local(DefaultContext()) as ctx:
res_other = other(curve.prime, *inputs, **curve.parameters)
- action_other = ctx.actions.get_by_index([0])
+ action_other = ctx.actions.action
ivs_other = set(
- map(attrgetter("value"), sum(action_other[0].intermediates.values(), []))
+ map(attrgetter("value"), sum(action_other.intermediates.values(), []))
)
iv_matches += len(ivs_one.intersection(ivs_other)) / max(
len(ivs_one), len(ivs_other)
diff --git a/pyecsca/sca/attack/CPA.py b/pyecsca/sca/attack/CPA.py
index 77c9abb..6efa3b5 100644
--- a/pyecsca/sca/attack/CPA.py
+++ b/pyecsca/sca/attack/CPA.py
@@ -55,7 +55,7 @@ class CPA:
action_index += 2
elif bit == "0":
action_index += 1
- result = ctx.actions.get_by_index([0, action_index])[0]
+ result = ctx.actions.get_by_index([action_index]).action
return result.output_points[0].X
def compute_correlation_trace(
diff --git a/pyecsca/sca/attack/DPA.py b/pyecsca/sca/attack/DPA.py
index 661adc6..da4ae65 100644
--- a/pyecsca/sca/attack/DPA.py
+++ b/pyecsca/sca/attack/DPA.py
@@ -50,7 +50,7 @@ class DPA:
action_index += 2
elif bit == "0":
action_index += 1
- result = ctx.actions.get_by_index([0, action_index])[0]
+ result = ctx.actions.get_by_index([action_index]).action
return result.output_points[0]
def split_traces(
diff --git a/pyecsca/sca/target/leakage.py b/pyecsca/sca/target/leakage.py
index c832028..e8734fb 100644
--- a/pyecsca/sca/target/leakage.py
+++ b/pyecsca/sca/target/leakage.py
@@ -52,6 +52,8 @@ class LeakageTarget(Target):
temp_trace.append(leak)
temp_trace: list[int] = []
+ if not context.actions:
+ raise ValueError("Empty context")
context.actions.walk(callback)
return Trace(np.array(temp_trace))
diff --git a/test/ec/test_context.py b/test/ec/test_context.py
index 9ec7962..65053d4 100644
--- a/test/ec/test_context.py
+++ b/test/ec/test_context.py
@@ -3,8 +3,8 @@ import pytest
from pyecsca.ec.context import (
local,
DefaultContext,
- Tree,
- PathContext
+ Node,
+ PathContext, Action
)
from pyecsca.ec.key_generation import KeyGeneration
from pyecsca.ec.mod import RandomModAction
@@ -12,40 +12,49 @@ from pyecsca.ec.mult import LTRMultiplier, ScalarMultiplicationAction
def test_walk_by_key():
- tree = Tree()
- tree["a"] = Tree()
- tree["a"]["1"] = Tree()
- tree["a"]["2"] = Tree()
- assert "a" in tree
- assert isinstance(tree.get_by_key([]), Tree)
- assert isinstance(tree.get_by_key(["a"]), Tree)
- assert isinstance(tree.get_by_key(["a", "1"]), Tree)
+ tree = Node(Action())
+ a_a = Action()
+ a = Node(a_a, parent=tree)
+ one_a = Action()
+ one = Node(one_a, parent=a)
+ other_a = Action()
+ other = Node(other_a, parent=a)
+
+ assert tree.get_by_key([]) == tree
+ assert tree.get_by_key([a_a]) == a
+ assert tree.get_by_key(([a_a, one_a])) == one
+ assert tree.get_by_key(([a_a, other_a])) == other
def test_walk_by_index():
- tree = Tree()
- a = Tree()
- tree["a"] = a
- d = Tree()
- b = Tree()
- tree["a"]["d"] = d
- tree["a"]["b"] = b
- assert "a" in tree
- with pytest.raises(ValueError):
- tree.get_by_index([])
+ tree = Node(Action())
+ a_a = Action()
+ a = Node(a_a, parent=tree)
+ one_a = Action()
+ one = Node(one_a, parent=a)
+ other_a = Action()
+ other = Node(other_a, parent=a)
+
+ assert tree.get_by_index([]) == tree
+ assert tree.get_by_index([0]) == a
+ assert tree.get_by_index([0, 0]) == one
+ assert tree.get_by_index([0, 1]) == other
- assert tree.get_by_index([0]) == ("a", a)
- assert tree.get_by_index([0, 0]) == ("d", d)
+def test_render():
+ tree = Node(Action())
+ a_a = Action()
+ a = Node(a_a, parent=tree)
+ one_a = Action()
+ Node(one_a, parent=a)
+ other_a = Action()
+ Node(other_a, parent=a)
-def test_repr():
- tree = Tree()
- tree["a"] = Tree()
- tree["a"]["1"] = Tree()
- tree["a"]["2"] = Tree()
- txt = tree.repr()
- assert txt.count("\t") == 2
- assert txt.count("\n") == 3
+ txt = tree.render()
+ assert txt == """Action()
+└──Action()
+ ├──Action()
+ └──Action()"""
@pytest.fixture()
@@ -71,8 +80,7 @@ def test_null(mult):
def test_default(mult):
with local(DefaultContext()) as ctx:
result = mult.multiply(59)
- assert len(ctx.actions) == 1
- action = next(iter(ctx.actions.keys()))
+ action = ctx.actions.action
assert isinstance(action, ScalarMultiplicationAction)
assert result == action.result