diff options
| -rw-r--r-- | pyecsca/ec/mult.py | 27 | ||||
| -rw-r--r-- | pyecsca/sca/re/rpa.py | 74 |
2 files changed, 94 insertions, 7 deletions
diff --git a/pyecsca/ec/mult.py b/pyecsca/ec/mult.py index 9b13374..2effb10 100644 --- a/pyecsca/ec/mult.py +++ b/pyecsca/ec/mult.py @@ -155,6 +155,9 @@ class ScalarMultiplier(ABC): self._params.curve.prime, point, **self._params.curve.parameters )[0] + def __hash__(self): + return id(self) + def __eq__(self, other): if not isinstance(other, ScalarMultiplier): return False @@ -223,6 +226,9 @@ class LTRMultiplier(ScalarMultiplier): self.always = always self.complete = complete + def __hash__(self): + return id(self) + def __eq__(self, other): if not isinstance(other, LTRMultiplier): return False @@ -277,6 +283,9 @@ class RTLMultiplier(ScalarMultiplier): super().__init__(short_circuit=short_circuit, add=add, dbl=dbl, scl=scl) self.always = always + def __hash__(self): + return id(self) + def __eq__(self, other): if not isinstance(other, RTLMultiplier): return False @@ -326,6 +335,9 @@ class CoronMultiplier(ScalarMultiplier): ): super().__init__(short_circuit=short_circuit, add=add, dbl=dbl, scl=scl) + def __hash__(self): + return id(self) + def __eq__(self, other): if not isinstance(other, CoronMultiplier): return False @@ -370,6 +382,9 @@ class LadderMultiplier(ScalarMultiplier): if (not complete or short_circuit) and dbl is None: raise ValueError + def __hash__(self): + return id(self) + def __eq__(self, other): if not isinstance(other, LadderMultiplier): return False @@ -419,6 +434,9 @@ class SimpleLadderMultiplier(ScalarMultiplier): super().__init__(short_circuit=short_circuit, add=add, dbl=dbl, scl=scl) self.complete = complete + def __hash__(self): + return id(self) + def __eq__(self, other): if not isinstance(other, SimpleLadderMultiplier): return False @@ -467,6 +485,9 @@ class DifferentialLadderMultiplier(ScalarMultiplier): super().__init__(short_circuit=short_circuit, dadd=dadd, dbl=dbl, scl=scl) self.complete = complete + def __hash__(self): + return id(self) + def __eq__(self, other): if not isinstance(other, DifferentialLadderMultiplier): return False @@ -517,6 +538,9 @@ class BinaryNAFMultiplier(ScalarMultiplier): short_circuit=short_circuit, add=add, dbl=dbl, neg=neg, scl=scl ) + def __hash__(self): + return id(self) + def __eq__(self, other): if not isinstance(other, BinaryNAFMultiplier): return False @@ -573,6 +597,9 @@ class WindowNAFMultiplier(ScalarMultiplier): self.width = width self.precompute_negation = precompute_negation + def __hash__(self): + return id(self) + def __eq__(self, other): if not isinstance(other, WindowNAFMultiplier): return False diff --git a/pyecsca/sca/re/rpa.py b/pyecsca/sca/re/rpa.py index 7a7ca17..5fd07c6 100644 --- a/pyecsca/sca/re/rpa.py +++ b/pyecsca/sca/re/rpa.py @@ -5,7 +5,8 @@ Provides functionality inspired by the Refined-Power Analysis attack by Goubin. `<https://dl.acm.org/doi/10.5555/648120.747060>`_ """ from public import public -from typing import MutableMapping, Optional +from typing import MutableMapping, Optional, Callable +from collections import Counter from sympy import FF, sympify, Poly, symbols @@ -20,11 +21,11 @@ from ...ec.formula import ( LadderFormula, ) from ...ec.mod import Mod -from ...ec.mult import ScalarMultiplicationAction, PrecomputationAction +from ...ec.mult import ScalarMultiplicationAction, PrecomputationAction, ScalarMultiplier from ...ec.params import DomainParameters from ...ec.model import ShortWeierstrassModel, MontgomeryModel from ...ec.point import Point -from ...ec.context import Context, Action +from ...ec.context import Context, Action, local @public @@ -87,8 +88,8 @@ class MultipleContext(Context): return f"{self.__class__.__name__}({self.base!r}, multiples={self.points.values()!r})" -def rpa_point_0y(params: DomainParameters): - """Construct a RPA point (0, y) for given domain parameters.""" +def rpa_point_0y(params: DomainParameters) -> Optional[Point]: + """Construct an (affine) RPA point (0, y) for given domain parameters.""" if isinstance(params.curve.model, ShortWeierstrassModel): if not params.curve.parameters["b"].is_residue(): return None @@ -102,8 +103,8 @@ def rpa_point_0y(params: DomainParameters): raise NotImplementedError -def rpa_point_x0(params: DomainParameters): - """Construct a RPA point (x, 0) for given domain parameters.""" +def rpa_point_x0(params: DomainParameters) -> Optional[Point]: + """Construct an (affine) RPA point (x, 0) for given domain parameters.""" if isinstance(params.curve.model, ShortWeierstrassModel): if (params.order * params.cofactor) % 2 != 0: return None @@ -123,3 +124,62 @@ def rpa_point_x0(params: DomainParameters): y=Mod(0, params.curve.prime)) else: raise NotImplementedError + + +def rpa_distinguish(params: DomainParameters, mults: list[ScalarMultiplier], oracle: Callable[[int, Point], bool]) -> list[ScalarMultiplier]: + """ + Distinguish the scalar multiplier used (from the possible :paramref:`~.rpa_distinguish.mults`) using + an RPA :paramref:`~.rpa_distinguish.oracle`. + + :param params: The domain parameters to use. + :param mults: The list of possible multipliers. + :param oracle: An oracle that returns `True` when an RPA point is encountered during scalar multiplication of the input by the scalar. + :returns: The list of possible multipliers after distinguishing (ideally just one). + """ + P0 = rpa_point_x0(params) or rpa_point_0y(params) + if not P0: + raise ValueError("There are no RPA-points on the provided curve.") + print(f"Got RPA point {P0}") + while True: + scalar = int(Mod.random(params.order)) + print(f"Got scalar {scalar}") + print([mult.__class__.__name__ for mult in mults]) + mults_to_multiples = {} + counts: Counter = Counter() + for mult in mults: + with local(MultipleContext()) as ctx: + mult.init(params, params.generator) + mult.multiply(scalar) + multiples = set(ctx.points.values()) + mults_to_multiples[mult] = multiples + counts.update(multiples) + + # TODO: This lower part can be repeated a few times for the same scalar above, which could reuse + # the computed multiples. Can be done until there is some distinguishing multiple. + # However, the counts variable needs to be recomputed for the new subset of multipliers. + nhalf = len(mults) / 2 + best_distinguishing_multiple = None + best_count = None + best_nhalf_distance = None + for multiple, count in counts.items(): + if best_distinguishing_multiple is None or abs(count - nhalf) < best_nhalf_distance: + best_distinguishing_multiple = multiple + best_count = count + best_nhalf_distance = abs(count - nhalf) + print(f"Chosen best distinguishing multiple {best_distinguishing_multiple} count={best_count} n={len(mults)}") + if best_count in (0, len(mults)): + continue + + multiple_inverse = Mod(best_distinguishing_multiple, params.order).inverse() + P0_inverse = params.curve.affine_multiply(P0, int(multiple_inverse)) + response = oracle(scalar, P0_inverse) + print(f"Oracle response -> {response}") + for mult in mults: + print(mult.__class__.__name__, best_distinguishing_multiple in mults_to_multiples[mult]) + filt = (lambda mult: best_distinguishing_multiple in mults_to_multiples[mult]) if response else (lambda mult: best_distinguishing_multiple not in mults_to_multiples[mult]) + mults = list(filter(filt, mults)) + print([mult.__class__.__name__ for mult in mults]) + print() + + if len(mults) == 1: + return mults |
