diff options
| m--------- | notebook | 0 | ||||
| -rw-r--r-- | pyecsca/sca/re/base.py | 23 | ||||
| -rw-r--r-- | pyecsca/sca/re/rpa.py | 196 | ||||
| -rw-r--r-- | pyecsca/sca/re/tree.py | 7 | ||||
| -rw-r--r-- | test/sca/test_rpa.py | 4 |
5 files changed, 148 insertions, 82 deletions
diff --git a/notebook b/notebook -Subproject d814afb34f044c6ea1486e17a32c8d4691d9b6a +Subproject 57cadf3ea4ae80e999a2f18ab605b633cd78cbe diff --git a/pyecsca/sca/re/base.py b/pyecsca/sca/re/base.py new file mode 100644 index 0000000..1340a68 --- /dev/null +++ b/pyecsca/sca/re/base.py @@ -0,0 +1,23 @@ +from abc import abstractmethod, ABC +from typing import Optional, Any, Set + +from public import public + +from .tree import Tree + + +@public +class RE(ABC): + tree: Optional[Tree] = None + configs: Set[Any] + + def __init__(self, configs: Set[Any]): + self.configs = configs + + @abstractmethod + def build_tree(self, *args, **kwargs): + pass + + @abstractmethod + def run(self, *args, **kwargs): + pass diff --git a/pyecsca/sca/re/rpa.py b/pyecsca/sca/re/rpa.py index dbf9bcf..266602b 100644 --- a/pyecsca/sca/re/rpa.py +++ b/pyecsca/sca/re/rpa.py @@ -3,12 +3,12 @@ Provides functionality inspired by the Refined-Power Analysis attack by Goubin [ """ from copy import copy, deepcopy -from anytree import RenderTree from public import public -from typing import MutableMapping, Optional, Callable, List, Set +from typing import MutableMapping, Optional, Callable, List, Set, cast from sympy import FF, sympify, Poly, symbols +from .base import RE from .tree import Tree, Map from ...ec.coordinates import AffineCoordinateModel from ...ec.formula import ( @@ -75,6 +75,7 @@ class MultipleContext(Context): if isinstance(action, (ScalarMultiplicationAction, PrecomputationAction)): self.inside = False if isinstance(action, FormulaAction) and self.inside: + action = cast(FormulaAction, action) if isinstance(action.formula, DoublingFormula): inp = action.input_points[0] out = action.output_points[0] @@ -181,6 +182,7 @@ def rpa_distinguish( multipliers: List[ScalarMultiplier], oracle: Callable[[int, Point], bool], bound: Optional[int] = None, + tries: int = 10, majority: int = 1, use_init: bool = True, use_multiply: bool = True, @@ -193,89 +195,134 @@ def rpa_distinguish( :param multipliers: The list of possible multipliers. :param oracle: An oracle that returns `True` when an RPA point is encountered during scalar multiplication of the input by the scalar. :param bound: A bound on the size of the scalar to consider. + :param tries: Number of tries to get a non-trivial tree. :param majority: Query the oracle up to `majority` times and take the majority vote of the results. :param use_init: Whether to consider the point multiples that happen in scalarmult initialization. :param use_multiply: Whether to consider the point multiples that happen in scalarmult multiply (after initialization). :return: The list of possible multipliers after distinguishing (ideally just one). """ - if (majority % 2) == 0: - raise ValueError("Cannot use even majority.") - if not (use_init or use_multiply): - raise ValueError("Has to use either init or multiply or both.") - P0 = rpa_point_0y(params) - if not P0: - raise ValueError("There are no RPA-points on the provided curve.") - log(f"Got RPA point {P0}.") - if not bound: - bound = params.order + re = RPA(set(multipliers)) + re.build_tree(params, tries, bound, use_init, use_multiply) + return re.run(oracle, majority) - mults = set(copy(mult) for mult in multipliers) - init_contexts = {} - for mult in mults: - with local(MultipleContext()) as ctx: - mult.init(params, params.generator) - init_contexts[mult] = ctx - tries = 0 - while True: - if tries > 10: - warn("Tried more than 10 times. Aborting.") - return mults - scalar = int(Mod.random(bound)) - log(f"Got scalar {scalar}") - log([mult.__class__.__name__ for mult in mults]) - mults_to_multiples = {} +@public +class RPA(RE): + params: Optional[DomainParameters] = None + P0: Optional[Point] = None + scalars: Optional[List[int]] = None + + def build_tree( + self, + params: DomainParameters, + tries: int = 10, + bound: Optional[int] = None, + use_init: bool = True, + use_multiply: bool = True, + ): + if not (use_init or use_multiply): + raise ValueError("Has to use either init or multiply or both.") + P0 = rpa_point_0y(params) + if not P0: + raise ValueError("There are no RPA-points on the provided curve.") + if not bound: + bound = params.order + + mults = set(copy(mult) for mult in self.configs) + init_contexts = {} for mult in mults: - # Copy the context after init to not accumulate multiples by accident here. - init_context = deepcopy(init_contexts[mult]) - # Take the computed points during init - init_points = set(init_context.parents.keys()) - # And get their parents (inputs to formulas) - init_parents = set( - sum((init_context.parents[point] for point in init_points), []) - ) - # Go over the parents and map them to multiples of the base (plus-minus sign) - init_multiples = set( - map( - lambda v: Mod(v, params.order), - (init_context.points[parent] for parent in init_parents), + with local(MultipleContext()) as ctx: + mult.init(params, params.generator) + init_contexts[mult] = ctx + + done = 0 + tree = None + scalars = [] + while True: + scalar = int(Mod.random(bound)) + log(f"Got scalar {scalar}") + log([mult.__class__.__name__ for mult in mults]) + mults_to_multiples = {} + for mult in mults: + # Copy the context after init to not accumulate multiples by accident here. + init_context = deepcopy(init_contexts[mult]) + # Take the computed points during init + init_points = set(init_context.parents.keys()) + # And get their parents (inputs to formulas) + init_parents = set( + sum((init_context.parents[point] for point in init_points), []) ) - ) - init_multiples |= set(map(lambda v: -v, init_multiples)) - # Now do the multiply and repeat the above, but only consider new computed points - with local(init_context) as ctx: - mult.multiply(scalar) - all_points = set(ctx.parents.keys()) - multiply_parents = set( - sum((ctx.parents[point] for point in all_points - init_points), []) - ) - multiply_multiples = set( - map( - lambda v: Mod(v, params.order), - (ctx.points[parent] for parent in multiply_parents), + # Go over the parents and map them to multiples of the base (plus-minus sign) + init_multiples = set( + map( + lambda v: Mod(v, params.order), + (init_context.points[parent] for parent in init_parents), + ) ) - ) - multiply_multiples |= set(map(lambda v: -v, multiply_multiples)) - used = set() - if use_init: - used |= init_multiples - if use_multiply: - used |= multiply_multiples - mults_to_multiples[mult] = used + init_multiples |= set(map(lambda v: -v, init_multiples)) + # Now do the multiply and repeat the above, but only consider new computed points + with local(init_context) as ctx: + mult.multiply(scalar) + all_points = set(ctx.parents.keys()) + multiply_parents = set( + sum((ctx.parents[point] for point in all_points - init_points), []) + ) + multiply_multiples = set( + map( + lambda v: Mod(v, params.order), + (ctx.points[parent] for parent in multiply_parents), + ) + ) + multiply_multiples |= set(map(lambda v: -v, multiply_multiples)) + used = set() + if use_init: + used |= init_multiples + if use_multiply: + used |= multiply_multiples + mults_to_multiples[mult] = used + + dmap = Map.from_sets(set(mults), mults_to_multiples) + if tree is None: + tree = Tree.build(set(mults), dmap) + else: + tree = tree.expand(dmap) + + log("Built distinguishing tree.") + log(tree.render()) + scalars.append(scalar) + if not tree.precise: + done += 1 + if done > tries: + warn(f"Tried more than {tries} times. Aborting. Distinguishing may not be precise.") + break + else: + continue + else: + break + self.scalars = scalars + self.tree = tree + self.params = params + self.P0 = P0 - dmap = Map.from_sets(set(mults), mults_to_multiples) + def run( + self, oracle: Callable[[int, Point], bool], majority: int = 1 + ) -> Set[ScalarMultiplier]: + if self.tree is None or self.scalars is None: + raise ValueError("Need to build tree first.") - tree = Tree.build(set(mults), dmap) - log("Built distinguishing tree.") - log(tree.render()) - if tree.size == 1: - tries += 1 - continue - current_node = tree.root + if (majority % 2) == 0: + raise ValueError("Cannot use even majority.") + + current_node = self.tree.root + mults = current_node.cfgs while current_node.children: + scalar = self.scalars[current_node.dmap_index] # type: ignore best_distinguishing_multiple: Mod = current_node.dmap_input # type: ignore - P0_inverse = rpa_input_point(best_distinguishing_multiple, P0, params) + P0_inverse = rpa_input_point( + best_distinguishing_multiple, self.P0, self.params + ) responses = [] + response = True for _ in range(majority): responses.append(oracle(scalar, P0_inverse)) if responses.count(True) > (majority // 2): @@ -285,16 +332,9 @@ def rpa_distinguish( response = False break log(f"Oracle response -> {response}") - for mult in mults: - log( - mult.__class__.__name__, - best_distinguishing_multiple in mults_to_multiples[mult], - ) response_map = {child.response: child for child in current_node.children} current_node = response_map[response] mults = current_node.cfgs log([mult.__class__.__name__ for mult in mults]) log() - - if len(mults) == 1: - return mults + return mults diff --git a/pyecsca/sca/re/tree.py b/pyecsca/sca/re/tree.py index 2ae7945..4e9cf3e 100644 --- a/pyecsca/sca/re/tree.py +++ b/pyecsca/sca/re/tree.py @@ -199,7 +199,7 @@ class Map: "Rows: {len(self.mapping)}, ({self.mapping.memory_usage(index=True).sum():_} bytes)", "Inputs: {len(self.domain)}", "Codomain: {len(self.codomain)}", - "None in codomain: {None in self.codomain}" + "None in codomain: {None in self.codomain}", ) ) @@ -264,6 +264,11 @@ class Tree: """Get the size of the tree (number of nodes).""" return self.root.size + @property + def precise(self) -> bool: + """Whether the tree is precise (all leaves have only a single configuration).""" + return all(len(leaf.cfgs) == 1 for leaf in self.leaves) + def render(self) -> str: """Render the tree.""" style = AbstractStyle("\u2502 ", "\u251c\u2500\u2500", "\u2514\u2500\u2500") diff --git a/test/sca/test_rpa.py b/test/sca/test_rpa.py index 43f4183..af19ba2 100644 --- a/test/sca/test_rpa.py +++ b/test/sca/test_rpa.py @@ -180,8 +180,6 @@ def test_distinguish(secp128r1, add, dbl, neg): map(lambda P: P.X == 0 or P.Y == 0, sum(ctx.parents.values(), [])) ) - with redirect_stdout(io.StringIO()): - result = rpa_distinguish(secp128r1, multipliers, simulated_oracle) + result = rpa_distinguish(secp128r1, multipliers, simulated_oracle) assert real_mult in result assert 1 == len(result) - assert real_mult == result.pop() |
