aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--pyecsca/ec/mult.py27
-rw-r--r--pyecsca/sca/re/rpa.py74
2 files changed, 94 insertions, 7 deletions
diff --git a/pyecsca/ec/mult.py b/pyecsca/ec/mult.py
index 9b13374..2effb10 100644
--- a/pyecsca/ec/mult.py
+++ b/pyecsca/ec/mult.py
@@ -155,6 +155,9 @@ class ScalarMultiplier(ABC):
self._params.curve.prime, point, **self._params.curve.parameters
)[0]
+ def __hash__(self):
+ return id(self)
+
def __eq__(self, other):
if not isinstance(other, ScalarMultiplier):
return False
@@ -223,6 +226,9 @@ class LTRMultiplier(ScalarMultiplier):
self.always = always
self.complete = complete
+ def __hash__(self):
+ return id(self)
+
def __eq__(self, other):
if not isinstance(other, LTRMultiplier):
return False
@@ -277,6 +283,9 @@ class RTLMultiplier(ScalarMultiplier):
super().__init__(short_circuit=short_circuit, add=add, dbl=dbl, scl=scl)
self.always = always
+ def __hash__(self):
+ return id(self)
+
def __eq__(self, other):
if not isinstance(other, RTLMultiplier):
return False
@@ -326,6 +335,9 @@ class CoronMultiplier(ScalarMultiplier):
):
super().__init__(short_circuit=short_circuit, add=add, dbl=dbl, scl=scl)
+ def __hash__(self):
+ return id(self)
+
def __eq__(self, other):
if not isinstance(other, CoronMultiplier):
return False
@@ -370,6 +382,9 @@ class LadderMultiplier(ScalarMultiplier):
if (not complete or short_circuit) and dbl is None:
raise ValueError
+ def __hash__(self):
+ return id(self)
+
def __eq__(self, other):
if not isinstance(other, LadderMultiplier):
return False
@@ -419,6 +434,9 @@ class SimpleLadderMultiplier(ScalarMultiplier):
super().__init__(short_circuit=short_circuit, add=add, dbl=dbl, scl=scl)
self.complete = complete
+ def __hash__(self):
+ return id(self)
+
def __eq__(self, other):
if not isinstance(other, SimpleLadderMultiplier):
return False
@@ -467,6 +485,9 @@ class DifferentialLadderMultiplier(ScalarMultiplier):
super().__init__(short_circuit=short_circuit, dadd=dadd, dbl=dbl, scl=scl)
self.complete = complete
+ def __hash__(self):
+ return id(self)
+
def __eq__(self, other):
if not isinstance(other, DifferentialLadderMultiplier):
return False
@@ -517,6 +538,9 @@ class BinaryNAFMultiplier(ScalarMultiplier):
short_circuit=short_circuit, add=add, dbl=dbl, neg=neg, scl=scl
)
+ def __hash__(self):
+ return id(self)
+
def __eq__(self, other):
if not isinstance(other, BinaryNAFMultiplier):
return False
@@ -573,6 +597,9 @@ class WindowNAFMultiplier(ScalarMultiplier):
self.width = width
self.precompute_negation = precompute_negation
+ def __hash__(self):
+ return id(self)
+
def __eq__(self, other):
if not isinstance(other, WindowNAFMultiplier):
return False
diff --git a/pyecsca/sca/re/rpa.py b/pyecsca/sca/re/rpa.py
index 7a7ca17..5fd07c6 100644
--- a/pyecsca/sca/re/rpa.py
+++ b/pyecsca/sca/re/rpa.py
@@ -5,7 +5,8 @@ Provides functionality inspired by the Refined-Power Analysis attack by Goubin.
`<https://dl.acm.org/doi/10.5555/648120.747060>`_
"""
from public import public
-from typing import MutableMapping, Optional
+from typing import MutableMapping, Optional, Callable
+from collections import Counter
from sympy import FF, sympify, Poly, symbols
@@ -20,11 +21,11 @@ from ...ec.formula import (
LadderFormula,
)
from ...ec.mod import Mod
-from ...ec.mult import ScalarMultiplicationAction, PrecomputationAction
+from ...ec.mult import ScalarMultiplicationAction, PrecomputationAction, ScalarMultiplier
from ...ec.params import DomainParameters
from ...ec.model import ShortWeierstrassModel, MontgomeryModel
from ...ec.point import Point
-from ...ec.context import Context, Action
+from ...ec.context import Context, Action, local
@public
@@ -87,8 +88,8 @@ class MultipleContext(Context):
return f"{self.__class__.__name__}({self.base!r}, multiples={self.points.values()!r})"
-def rpa_point_0y(params: DomainParameters):
- """Construct a RPA point (0, y) for given domain parameters."""
+def rpa_point_0y(params: DomainParameters) -> Optional[Point]:
+ """Construct an (affine) RPA point (0, y) for given domain parameters."""
if isinstance(params.curve.model, ShortWeierstrassModel):
if not params.curve.parameters["b"].is_residue():
return None
@@ -102,8 +103,8 @@ def rpa_point_0y(params: DomainParameters):
raise NotImplementedError
-def rpa_point_x0(params: DomainParameters):
- """Construct a RPA point (x, 0) for given domain parameters."""
+def rpa_point_x0(params: DomainParameters) -> Optional[Point]:
+ """Construct an (affine) RPA point (x, 0) for given domain parameters."""
if isinstance(params.curve.model, ShortWeierstrassModel):
if (params.order * params.cofactor) % 2 != 0:
return None
@@ -123,3 +124,62 @@ def rpa_point_x0(params: DomainParameters):
y=Mod(0, params.curve.prime))
else:
raise NotImplementedError
+
+
+def rpa_distinguish(params: DomainParameters, mults: list[ScalarMultiplier], oracle: Callable[[int, Point], bool]) -> list[ScalarMultiplier]:
+ """
+ Distinguish the scalar multiplier used (from the possible :paramref:`~.rpa_distinguish.mults`) using
+ an RPA :paramref:`~.rpa_distinguish.oracle`.
+
+ :param params: The domain parameters to use.
+ :param mults: The list of possible multipliers.
+ :param oracle: An oracle that returns `True` when an RPA point is encountered during scalar multiplication of the input by the scalar.
+ :returns: The list of possible multipliers after distinguishing (ideally just one).
+ """
+ P0 = rpa_point_x0(params) or rpa_point_0y(params)
+ if not P0:
+ raise ValueError("There are no RPA-points on the provided curve.")
+ print(f"Got RPA point {P0}")
+ while True:
+ scalar = int(Mod.random(params.order))
+ print(f"Got scalar {scalar}")
+ print([mult.__class__.__name__ for mult in mults])
+ mults_to_multiples = {}
+ counts: Counter = Counter()
+ for mult in mults:
+ with local(MultipleContext()) as ctx:
+ mult.init(params, params.generator)
+ mult.multiply(scalar)
+ multiples = set(ctx.points.values())
+ mults_to_multiples[mult] = multiples
+ counts.update(multiples)
+
+ # TODO: This lower part can be repeated a few times for the same scalar above, which could reuse
+ # the computed multiples. Can be done until there is some distinguishing multiple.
+ # However, the counts variable needs to be recomputed for the new subset of multipliers.
+ nhalf = len(mults) / 2
+ best_distinguishing_multiple = None
+ best_count = None
+ best_nhalf_distance = None
+ for multiple, count in counts.items():
+ if best_distinguishing_multiple is None or abs(count - nhalf) < best_nhalf_distance:
+ best_distinguishing_multiple = multiple
+ best_count = count
+ best_nhalf_distance = abs(count - nhalf)
+ print(f"Chosen best distinguishing multiple {best_distinguishing_multiple} count={best_count} n={len(mults)}")
+ if best_count in (0, len(mults)):
+ continue
+
+ multiple_inverse = Mod(best_distinguishing_multiple, params.order).inverse()
+ P0_inverse = params.curve.affine_multiply(P0, int(multiple_inverse))
+ response = oracle(scalar, P0_inverse)
+ print(f"Oracle response -> {response}")
+ for mult in mults:
+ print(mult.__class__.__name__, best_distinguishing_multiple in mults_to_multiples[mult])
+ filt = (lambda mult: best_distinguishing_multiple in mults_to_multiples[mult]) if response else (lambda mult: best_distinguishing_multiple not in mults_to_multiples[mult])
+ mults = list(filter(filt, mults))
+ print([mult.__class__.__name__ for mult in mults])
+ print()
+
+ if len(mults) == 1:
+ return mults