diff options
Diffstat (limited to 'pyecsca/ec/countermeasures.py')
| -rw-r--r-- | pyecsca/ec/countermeasures.py | 128 |
1 files changed, 90 insertions, 38 deletions
diff --git a/pyecsca/ec/countermeasures.py b/pyecsca/ec/countermeasures.py index 86f7177..853260c 100644 --- a/pyecsca/ec/countermeasures.py +++ b/pyecsca/ec/countermeasures.py @@ -1,7 +1,7 @@ """Provides several countermeasures against side-channel attacks.""" from abc import ABC, abstractmethod -from typing import Optional, Callable +from typing import Optional, Callable, get_type_hints, ClassVar from public import public @@ -21,8 +21,10 @@ class ScalarMultiplierCountermeasure(ABC): and provides some scalar-splitting countermeasure. """ - mult: "ScalarMultiplier | ScalarMultiplierCountermeasure" - """The underlying scalar multiplier (or another countermeasure).""" + mults: list["ScalarMultiplier | ScalarMultiplierCountermeasure"] + """The underlying scalar multipliers (or another countermeasure).""" + nmults: ClassVar[int] + """The number of scalar multipliers required.""" params: Optional[DomainParameters] """The domain parameters, if any.""" point: Optional[Point] @@ -32,10 +34,14 @@ class ScalarMultiplierCountermeasure(ABC): def __init__( self, - mult: "ScalarMultiplier | ScalarMultiplierCountermeasure", + *mults: "ScalarMultiplier | ScalarMultiplierCountermeasure", rng: Callable[[int], Mod] = Mod.random, ): - self.mult = mult + self.mults = list(mults) + if len(self.mults) != self.nmults: + raise ValueError( + f"Expected {self.nmults} multipliers, got {len(self.mults)}." + ) self.rng = rng def init(self, params: DomainParameters, point: Point, bits: Optional[int] = None): @@ -60,6 +66,24 @@ class ScalarMultiplierCountermeasure(ABC): """ raise NotImplementedError + @classmethod + def from_single( + cls, mult: "ScalarMultiplier | ScalarMultiplierCountermeasure", **kwargs + ): + """ + Create an instance of the countermeasure from a single scalar multiplier. + + :param mult: The scalar multiplier to use. + :return: An instance of the countermeasure. + """ + th = get_type_hints(cls.__init__) + num = 0 + for name, arg_type in th.items(): + if name.startswith("mult"): + num += 1 + mults = [mult] * num + return cls(*mults, **kwargs) + @public class GroupScalarRandomization(ScalarMultiplierCountermeasure): @@ -75,7 +99,7 @@ class GroupScalarRandomization(ScalarMultiplierCountermeasure): &\textbf{return}\ [k + r n]G """ - + nmults = 1 rand_bits: int def __init__( @@ -88,7 +112,7 @@ class GroupScalarRandomization(ScalarMultiplierCountermeasure): :param mult: The multiplier to use. :param rand_bits: How many random bits to sample. """ - super().__init__(mult, rng) + super().__init__(mult, rng=rng) self.rand_bits = rand_bits def multiply(self, scalar: int) -> Point: @@ -99,12 +123,12 @@ class GroupScalarRandomization(ScalarMultiplierCountermeasure): mask = int(self.rng(1 << self.rand_bits)) masked_scalar = scalar + mask * order bits = max(self.bits, self.rand_bits + order.bit_length()) + 1 - self.mult.init( + self.mults[0].init( self.params, self.point, bits=bits, ) - return action.exit(self.mult.multiply(masked_scalar)) + return action.exit(self.mults[0].multiply(masked_scalar)) @public @@ -121,12 +145,13 @@ class AdditiveSplitting(ScalarMultiplierCountermeasure): &\textbf{return}\ [k - r]G + [r]G """ - + nmults = 2 add: Optional[AdditionFormula] def __init__( self, - mult: "ScalarMultiplier | ScalarMultiplierCountermeasure", + mult1: "ScalarMultiplier | ScalarMultiplierCountermeasure", + mult2: "ScalarMultiplier | ScalarMultiplierCountermeasure", rng: Callable[[int], Mod] = Mod.random, add: Optional[AdditionFormula] = None, ): @@ -134,14 +159,17 @@ class AdditiveSplitting(ScalarMultiplierCountermeasure): :param mult: The multiplier to use. :param add: Addition formula to use, if None, the formula from the multiplier is used. """ - super().__init__(mult, rng) + super().__init__(mult1, mult2, rng=rng) self.add = add def _add(self, R: Point, S: Point) -> Point: # noqa if self.add is None: - try: - return self.mult._add(R, S) # type: ignore - except AttributeError: + for mult in self.mults: + try: + return mult._add(R, S) # type: ignore + except AttributeError: + pass + else: raise ValueError("No addition formula available.") else: return self.add( @@ -156,9 +184,14 @@ class AdditiveSplitting(ScalarMultiplierCountermeasure): r = self.rng(order) s = scalar - r bits = max(self.bits, order.bit_length()) - self.mult.init(self.params, self.point, bits) - R = self.mult.multiply(int(r)) - S = self.mult.multiply(int(s)) + + self.mults[0].init(self.params, self.point, bits) + R = self.mults[0].multiply(int(r)) + + if self.mults[0] != self.mults[1]: + self.mults[1].init(self.params, self.point, bits) + S = self.mults[1].multiply(int(s)) + res = self._add(R, S) return action.exit(res) @@ -178,12 +211,13 @@ class MultiplicativeSplitting(ScalarMultiplierCountermeasure): &\textbf{return}\ [k r^{-1} \mod n]S """ - + nmults = 2 rand_bits: int def __init__( self, - mult: "ScalarMultiplier | ScalarMultiplierCountermeasure", + mult1: "ScalarMultiplier | ScalarMultiplierCountermeasure", + mult2: "ScalarMultiplier | ScalarMultiplierCountermeasure", rng: Callable[[int], Mod] = Mod.random, rand_bits: int = 32, ): @@ -191,7 +225,7 @@ class MultiplicativeSplitting(ScalarMultiplierCountermeasure): :param mult: The multiplier to use. :param rand_bits: How many random bits to sample. """ - super().__init__(mult, rng) + super().__init__(mult1, mult2, rng=rng) self.rand_bits = rand_bits def multiply(self, scalar: int) -> Point: @@ -199,14 +233,14 @@ class MultiplicativeSplitting(ScalarMultiplierCountermeasure): raise ValueError("Not initialized.") with ScalarMultiplicationAction(self.point, self.params, scalar) as action: r = self.rng(1 << self.rand_bits) - self.mult.init(self.params, self.point, self.rand_bits) - R = self.mult.multiply(int(r)) + self.mults[0].init(self.params, self.point, self.rand_bits) + R = self.mults[0].multiply(int(r)) - self.mult.init( + self.mults[1].init( self.params, R, max(self.bits, self.params.order.bit_length()) ) kr_inv = scalar * mod(int(r), self.params.order).inverse() - return action.exit(self.mult.multiply(int(kr_inv))) + return action.exit(self.mults[1].multiply(int(kr_inv))) @public @@ -227,12 +261,14 @@ class EuclideanSplitting(ScalarMultiplierCountermeasure): &\textbf{return}\ [k_1]G + [k_2]S """ - + nmults = 3 add: Optional[AdditionFormula] def __init__( self, - mult: "ScalarMultiplier | ScalarMultiplierCountermeasure", + mult1: "ScalarMultiplier | ScalarMultiplierCountermeasure", + mult2: "ScalarMultiplier | ScalarMultiplierCountermeasure", + mult3: "ScalarMultiplier | ScalarMultiplierCountermeasure", rng: Callable[[int], Mod] = Mod.random, add: Optional[AdditionFormula] = None, ): @@ -240,14 +276,17 @@ class EuclideanSplitting(ScalarMultiplierCountermeasure): :param mult: The multiplier to use. :param add: Addition formula to use, if None, the formula from the multiplier is used. """ - super().__init__(mult, rng) + super().__init__(mult1, mult2, mult3, rng=rng) self.add = add def _add(self, R: Point, S: Point) -> Point: # noqa if self.add is None: - try: - return self.mult._add(R, S) # type: ignore - except AttributeError: + for mult in self.mults: + try: + return mult._add(R, S) # type: ignore + except AttributeError: + pass + else: raise ValueError("No addition formula available.") else: return self.add( @@ -260,15 +299,17 @@ class EuclideanSplitting(ScalarMultiplierCountermeasure): with ScalarMultiplicationAction(self.point, self.params, scalar) as action: half_bits = self.bits // 2 r = self.rng(1 << half_bits) - self.mult.init(self.params, self.point, half_bits) - R = self.mult.multiply(int(r)) # r bounded by half_bits + self.mults[0].init(self.params, self.point, half_bits) + R = self.mults[0].multiply(int(r)) # r bounded by half_bits + if self.mults[0] != self.mults[1]: + self.mults[1].init(self.params, self.point, half_bits) k1 = scalar % int(r) k2 = scalar // int(r) - T = self.mult.multiply(k1) # k1 bounded by half_bits + T = self.mults[1].multiply(k1) # k1 bounded by half_bits - self.mult.init(self.params, R, self.bits) - S = self.mult.multiply( + self.mults[2].init(self.params, R, self.bits) + S = self.mults[2].multiply( k2 ) # k2 (in worst case) bounded by bits, but in practice closer to half_bits @@ -294,13 +335,24 @@ class BrumleyTuveri(ScalarMultiplierCountermeasure): &\textbf{return}\ [\hat{k}]G """ + nmults = 1 + + def __init__( + self, + mult: "ScalarMultiplier | ScalarMultiplierCountermeasure", + rng: Callable[[int], Mod] = Mod.random, + ): + """ + :param mult: The multiplier to use. + """ + super().__init__(mult, rng=rng) def multiply(self, scalar: int) -> Point: if self.params is None or self.point is None or self.bits is None: raise ValueError("Not initialized.") with ScalarMultiplicationAction(self.point, self.params, scalar) as action: n = self.params.order - self.mult.init( + self.mults[0].init( self.params, self.point, bits=max(self.bits, n.bit_length()) + 1, @@ -308,4 +360,4 @@ class BrumleyTuveri(ScalarMultiplierCountermeasure): scalar += n if scalar.bit_length() <= n.bit_length(): scalar += n - return action.exit(self.mult.multiply(scalar)) + return action.exit(self.mults[0].multiply(scalar)) |
