aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
m---------notebook0
-rw-r--r--pyecsca/sca/re/base.py23
-rw-r--r--pyecsca/sca/re/rpa.py196
-rw-r--r--pyecsca/sca/re/tree.py7
-rw-r--r--test/sca/test_rpa.py4
5 files changed, 148 insertions, 82 deletions
diff --git a/notebook b/notebook
-Subproject d814afb34f044c6ea1486e17a32c8d4691d9b6a
+Subproject 57cadf3ea4ae80e999a2f18ab605b633cd78cbe
diff --git a/pyecsca/sca/re/base.py b/pyecsca/sca/re/base.py
new file mode 100644
index 0000000..1340a68
--- /dev/null
+++ b/pyecsca/sca/re/base.py
@@ -0,0 +1,23 @@
+from abc import abstractmethod, ABC
+from typing import Optional, Any, Set
+
+from public import public
+
+from .tree import Tree
+
+
+@public
+class RE(ABC):
+ tree: Optional[Tree] = None
+ configs: Set[Any]
+
+ def __init__(self, configs: Set[Any]):
+ self.configs = configs
+
+ @abstractmethod
+ def build_tree(self, *args, **kwargs):
+ pass
+
+ @abstractmethod
+ def run(self, *args, **kwargs):
+ pass
diff --git a/pyecsca/sca/re/rpa.py b/pyecsca/sca/re/rpa.py
index dbf9bcf..266602b 100644
--- a/pyecsca/sca/re/rpa.py
+++ b/pyecsca/sca/re/rpa.py
@@ -3,12 +3,12 @@ Provides functionality inspired by the Refined-Power Analysis attack by Goubin [
"""
from copy import copy, deepcopy
-from anytree import RenderTree
from public import public
-from typing import MutableMapping, Optional, Callable, List, Set
+from typing import MutableMapping, Optional, Callable, List, Set, cast
from sympy import FF, sympify, Poly, symbols
+from .base import RE
from .tree import Tree, Map
from ...ec.coordinates import AffineCoordinateModel
from ...ec.formula import (
@@ -75,6 +75,7 @@ class MultipleContext(Context):
if isinstance(action, (ScalarMultiplicationAction, PrecomputationAction)):
self.inside = False
if isinstance(action, FormulaAction) and self.inside:
+ action = cast(FormulaAction, action)
if isinstance(action.formula, DoublingFormula):
inp = action.input_points[0]
out = action.output_points[0]
@@ -181,6 +182,7 @@ def rpa_distinguish(
multipliers: List[ScalarMultiplier],
oracle: Callable[[int, Point], bool],
bound: Optional[int] = None,
+ tries: int = 10,
majority: int = 1,
use_init: bool = True,
use_multiply: bool = True,
@@ -193,89 +195,134 @@ def rpa_distinguish(
:param multipliers: The list of possible multipliers.
:param oracle: An oracle that returns `True` when an RPA point is encountered during scalar multiplication of the input by the scalar.
:param bound: A bound on the size of the scalar to consider.
+ :param tries: Number of tries to get a non-trivial tree.
:param majority: Query the oracle up to `majority` times and take the majority vote of the results.
:param use_init: Whether to consider the point multiples that happen in scalarmult initialization.
:param use_multiply: Whether to consider the point multiples that happen in scalarmult multiply (after initialization).
:return: The list of possible multipliers after distinguishing (ideally just one).
"""
- if (majority % 2) == 0:
- raise ValueError("Cannot use even majority.")
- if not (use_init or use_multiply):
- raise ValueError("Has to use either init or multiply or both.")
- P0 = rpa_point_0y(params)
- if not P0:
- raise ValueError("There are no RPA-points on the provided curve.")
- log(f"Got RPA point {P0}.")
- if not bound:
- bound = params.order
+ re = RPA(set(multipliers))
+ re.build_tree(params, tries, bound, use_init, use_multiply)
+ return re.run(oracle, majority)
- mults = set(copy(mult) for mult in multipliers)
- init_contexts = {}
- for mult in mults:
- with local(MultipleContext()) as ctx:
- mult.init(params, params.generator)
- init_contexts[mult] = ctx
- tries = 0
- while True:
- if tries > 10:
- warn("Tried more than 10 times. Aborting.")
- return mults
- scalar = int(Mod.random(bound))
- log(f"Got scalar {scalar}")
- log([mult.__class__.__name__ for mult in mults])
- mults_to_multiples = {}
+@public
+class RPA(RE):
+ params: Optional[DomainParameters] = None
+ P0: Optional[Point] = None
+ scalars: Optional[List[int]] = None
+
+ def build_tree(
+ self,
+ params: DomainParameters,
+ tries: int = 10,
+ bound: Optional[int] = None,
+ use_init: bool = True,
+ use_multiply: bool = True,
+ ):
+ if not (use_init or use_multiply):
+ raise ValueError("Has to use either init or multiply or both.")
+ P0 = rpa_point_0y(params)
+ if not P0:
+ raise ValueError("There are no RPA-points on the provided curve.")
+ if not bound:
+ bound = params.order
+
+ mults = set(copy(mult) for mult in self.configs)
+ init_contexts = {}
for mult in mults:
- # Copy the context after init to not accumulate multiples by accident here.
- init_context = deepcopy(init_contexts[mult])
- # Take the computed points during init
- init_points = set(init_context.parents.keys())
- # And get their parents (inputs to formulas)
- init_parents = set(
- sum((init_context.parents[point] for point in init_points), [])
- )
- # Go over the parents and map them to multiples of the base (plus-minus sign)
- init_multiples = set(
- map(
- lambda v: Mod(v, params.order),
- (init_context.points[parent] for parent in init_parents),
+ with local(MultipleContext()) as ctx:
+ mult.init(params, params.generator)
+ init_contexts[mult] = ctx
+
+ done = 0
+ tree = None
+ scalars = []
+ while True:
+ scalar = int(Mod.random(bound))
+ log(f"Got scalar {scalar}")
+ log([mult.__class__.__name__ for mult in mults])
+ mults_to_multiples = {}
+ for mult in mults:
+ # Copy the context after init to not accumulate multiples by accident here.
+ init_context = deepcopy(init_contexts[mult])
+ # Take the computed points during init
+ init_points = set(init_context.parents.keys())
+ # And get their parents (inputs to formulas)
+ init_parents = set(
+ sum((init_context.parents[point] for point in init_points), [])
)
- )
- init_multiples |= set(map(lambda v: -v, init_multiples))
- # Now do the multiply and repeat the above, but only consider new computed points
- with local(init_context) as ctx:
- mult.multiply(scalar)
- all_points = set(ctx.parents.keys())
- multiply_parents = set(
- sum((ctx.parents[point] for point in all_points - init_points), [])
- )
- multiply_multiples = set(
- map(
- lambda v: Mod(v, params.order),
- (ctx.points[parent] for parent in multiply_parents),
+ # Go over the parents and map them to multiples of the base (plus-minus sign)
+ init_multiples = set(
+ map(
+ lambda v: Mod(v, params.order),
+ (init_context.points[parent] for parent in init_parents),
+ )
)
- )
- multiply_multiples |= set(map(lambda v: -v, multiply_multiples))
- used = set()
- if use_init:
- used |= init_multiples
- if use_multiply:
- used |= multiply_multiples
- mults_to_multiples[mult] = used
+ init_multiples |= set(map(lambda v: -v, init_multiples))
+ # Now do the multiply and repeat the above, but only consider new computed points
+ with local(init_context) as ctx:
+ mult.multiply(scalar)
+ all_points = set(ctx.parents.keys())
+ multiply_parents = set(
+ sum((ctx.parents[point] for point in all_points - init_points), [])
+ )
+ multiply_multiples = set(
+ map(
+ lambda v: Mod(v, params.order),
+ (ctx.points[parent] for parent in multiply_parents),
+ )
+ )
+ multiply_multiples |= set(map(lambda v: -v, multiply_multiples))
+ used = set()
+ if use_init:
+ used |= init_multiples
+ if use_multiply:
+ used |= multiply_multiples
+ mults_to_multiples[mult] = used
+
+ dmap = Map.from_sets(set(mults), mults_to_multiples)
+ if tree is None:
+ tree = Tree.build(set(mults), dmap)
+ else:
+ tree = tree.expand(dmap)
+
+ log("Built distinguishing tree.")
+ log(tree.render())
+ scalars.append(scalar)
+ if not tree.precise:
+ done += 1
+ if done > tries:
+ warn(f"Tried more than {tries} times. Aborting. Distinguishing may not be precise.")
+ break
+ else:
+ continue
+ else:
+ break
+ self.scalars = scalars
+ self.tree = tree
+ self.params = params
+ self.P0 = P0
- dmap = Map.from_sets(set(mults), mults_to_multiples)
+ def run(
+ self, oracle: Callable[[int, Point], bool], majority: int = 1
+ ) -> Set[ScalarMultiplier]:
+ if self.tree is None or self.scalars is None:
+ raise ValueError("Need to build tree first.")
- tree = Tree.build(set(mults), dmap)
- log("Built distinguishing tree.")
- log(tree.render())
- if tree.size == 1:
- tries += 1
- continue
- current_node = tree.root
+ if (majority % 2) == 0:
+ raise ValueError("Cannot use even majority.")
+
+ current_node = self.tree.root
+ mults = current_node.cfgs
while current_node.children:
+ scalar = self.scalars[current_node.dmap_index] # type: ignore
best_distinguishing_multiple: Mod = current_node.dmap_input # type: ignore
- P0_inverse = rpa_input_point(best_distinguishing_multiple, P0, params)
+ P0_inverse = rpa_input_point(
+ best_distinguishing_multiple, self.P0, self.params
+ )
responses = []
+ response = True
for _ in range(majority):
responses.append(oracle(scalar, P0_inverse))
if responses.count(True) > (majority // 2):
@@ -285,16 +332,9 @@ def rpa_distinguish(
response = False
break
log(f"Oracle response -> {response}")
- for mult in mults:
- log(
- mult.__class__.__name__,
- best_distinguishing_multiple in mults_to_multiples[mult],
- )
response_map = {child.response: child for child in current_node.children}
current_node = response_map[response]
mults = current_node.cfgs
log([mult.__class__.__name__ for mult in mults])
log()
-
- if len(mults) == 1:
- return mults
+ return mults
diff --git a/pyecsca/sca/re/tree.py b/pyecsca/sca/re/tree.py
index 2ae7945..4e9cf3e 100644
--- a/pyecsca/sca/re/tree.py
+++ b/pyecsca/sca/re/tree.py
@@ -199,7 +199,7 @@ class Map:
"Rows: {len(self.mapping)}, ({self.mapping.memory_usage(index=True).sum():_} bytes)",
"Inputs: {len(self.domain)}",
"Codomain: {len(self.codomain)}",
- "None in codomain: {None in self.codomain}"
+ "None in codomain: {None in self.codomain}",
)
)
@@ -264,6 +264,11 @@ class Tree:
"""Get the size of the tree (number of nodes)."""
return self.root.size
+ @property
+ def precise(self) -> bool:
+ """Whether the tree is precise (all leaves have only a single configuration)."""
+ return all(len(leaf.cfgs) == 1 for leaf in self.leaves)
+
def render(self) -> str:
"""Render the tree."""
style = AbstractStyle("\u2502 ", "\u251c\u2500\u2500", "\u2514\u2500\u2500")
diff --git a/test/sca/test_rpa.py b/test/sca/test_rpa.py
index 43f4183..af19ba2 100644
--- a/test/sca/test_rpa.py
+++ b/test/sca/test_rpa.py
@@ -180,8 +180,6 @@ def test_distinguish(secp128r1, add, dbl, neg):
map(lambda P: P.X == 0 or P.Y == 0, sum(ctx.parents.values(), []))
)
- with redirect_stdout(io.StringIO()):
- result = rpa_distinguish(secp128r1, multipliers, simulated_oracle)
+ result = rpa_distinguish(secp128r1, multipliers, simulated_oracle)
assert real_mult in result
assert 1 == len(result)
- assert real_mult == result.pop()