diff options
| -rw-r--r-- | pyecsca/ec/countermeasures.py | 128 | ||||
| -rw-r--r-- | test/ec/test_countermeasures.py | 28 |
2 files changed, 107 insertions, 49 deletions
diff --git a/pyecsca/ec/countermeasures.py b/pyecsca/ec/countermeasures.py index 86f7177..853260c 100644 --- a/pyecsca/ec/countermeasures.py +++ b/pyecsca/ec/countermeasures.py @@ -1,7 +1,7 @@ """Provides several countermeasures against side-channel attacks.""" from abc import ABC, abstractmethod -from typing import Optional, Callable +from typing import Optional, Callable, get_type_hints, ClassVar from public import public @@ -21,8 +21,10 @@ class ScalarMultiplierCountermeasure(ABC): and provides some scalar-splitting countermeasure. """ - mult: "ScalarMultiplier | ScalarMultiplierCountermeasure" - """The underlying scalar multiplier (or another countermeasure).""" + mults: list["ScalarMultiplier | ScalarMultiplierCountermeasure"] + """The underlying scalar multipliers (or another countermeasure).""" + nmults: ClassVar[int] + """The number of scalar multipliers required.""" params: Optional[DomainParameters] """The domain parameters, if any.""" point: Optional[Point] @@ -32,10 +34,14 @@ class ScalarMultiplierCountermeasure(ABC): def __init__( self, - mult: "ScalarMultiplier | ScalarMultiplierCountermeasure", + *mults: "ScalarMultiplier | ScalarMultiplierCountermeasure", rng: Callable[[int], Mod] = Mod.random, ): - self.mult = mult + self.mults = list(mults) + if len(self.mults) != self.nmults: + raise ValueError( + f"Expected {self.nmults} multipliers, got {len(self.mults)}." + ) self.rng = rng def init(self, params: DomainParameters, point: Point, bits: Optional[int] = None): @@ -60,6 +66,24 @@ class ScalarMultiplierCountermeasure(ABC): """ raise NotImplementedError + @classmethod + def from_single( + cls, mult: "ScalarMultiplier | ScalarMultiplierCountermeasure", **kwargs + ): + """ + Create an instance of the countermeasure from a single scalar multiplier. + + :param mult: The scalar multiplier to use. + :return: An instance of the countermeasure. + """ + th = get_type_hints(cls.__init__) + num = 0 + for name, arg_type in th.items(): + if name.startswith("mult"): + num += 1 + mults = [mult] * num + return cls(*mults, **kwargs) + @public class GroupScalarRandomization(ScalarMultiplierCountermeasure): @@ -75,7 +99,7 @@ class GroupScalarRandomization(ScalarMultiplierCountermeasure): &\textbf{return}\ [k + r n]G """ - + nmults = 1 rand_bits: int def __init__( @@ -88,7 +112,7 @@ class GroupScalarRandomization(ScalarMultiplierCountermeasure): :param mult: The multiplier to use. :param rand_bits: How many random bits to sample. """ - super().__init__(mult, rng) + super().__init__(mult, rng=rng) self.rand_bits = rand_bits def multiply(self, scalar: int) -> Point: @@ -99,12 +123,12 @@ class GroupScalarRandomization(ScalarMultiplierCountermeasure): mask = int(self.rng(1 << self.rand_bits)) masked_scalar = scalar + mask * order bits = max(self.bits, self.rand_bits + order.bit_length()) + 1 - self.mult.init( + self.mults[0].init( self.params, self.point, bits=bits, ) - return action.exit(self.mult.multiply(masked_scalar)) + return action.exit(self.mults[0].multiply(masked_scalar)) @public @@ -121,12 +145,13 @@ class AdditiveSplitting(ScalarMultiplierCountermeasure): &\textbf{return}\ [k - r]G + [r]G """ - + nmults = 2 add: Optional[AdditionFormula] def __init__( self, - mult: "ScalarMultiplier | ScalarMultiplierCountermeasure", + mult1: "ScalarMultiplier | ScalarMultiplierCountermeasure", + mult2: "ScalarMultiplier | ScalarMultiplierCountermeasure", rng: Callable[[int], Mod] = Mod.random, add: Optional[AdditionFormula] = None, ): @@ -134,14 +159,17 @@ class AdditiveSplitting(ScalarMultiplierCountermeasure): :param mult: The multiplier to use. :param add: Addition formula to use, if None, the formula from the multiplier is used. """ - super().__init__(mult, rng) + super().__init__(mult1, mult2, rng=rng) self.add = add def _add(self, R: Point, S: Point) -> Point: # noqa if self.add is None: - try: - return self.mult._add(R, S) # type: ignore - except AttributeError: + for mult in self.mults: + try: + return mult._add(R, S) # type: ignore + except AttributeError: + pass + else: raise ValueError("No addition formula available.") else: return self.add( @@ -156,9 +184,14 @@ class AdditiveSplitting(ScalarMultiplierCountermeasure): r = self.rng(order) s = scalar - r bits = max(self.bits, order.bit_length()) - self.mult.init(self.params, self.point, bits) - R = self.mult.multiply(int(r)) - S = self.mult.multiply(int(s)) + + self.mults[0].init(self.params, self.point, bits) + R = self.mults[0].multiply(int(r)) + + if self.mults[0] != self.mults[1]: + self.mults[1].init(self.params, self.point, bits) + S = self.mults[1].multiply(int(s)) + res = self._add(R, S) return action.exit(res) @@ -178,12 +211,13 @@ class MultiplicativeSplitting(ScalarMultiplierCountermeasure): &\textbf{return}\ [k r^{-1} \mod n]S """ - + nmults = 2 rand_bits: int def __init__( self, - mult: "ScalarMultiplier | ScalarMultiplierCountermeasure", + mult1: "ScalarMultiplier | ScalarMultiplierCountermeasure", + mult2: "ScalarMultiplier | ScalarMultiplierCountermeasure", rng: Callable[[int], Mod] = Mod.random, rand_bits: int = 32, ): @@ -191,7 +225,7 @@ class MultiplicativeSplitting(ScalarMultiplierCountermeasure): :param mult: The multiplier to use. :param rand_bits: How many random bits to sample. """ - super().__init__(mult, rng) + super().__init__(mult1, mult2, rng=rng) self.rand_bits = rand_bits def multiply(self, scalar: int) -> Point: @@ -199,14 +233,14 @@ class MultiplicativeSplitting(ScalarMultiplierCountermeasure): raise ValueError("Not initialized.") with ScalarMultiplicationAction(self.point, self.params, scalar) as action: r = self.rng(1 << self.rand_bits) - self.mult.init(self.params, self.point, self.rand_bits) - R = self.mult.multiply(int(r)) + self.mults[0].init(self.params, self.point, self.rand_bits) + R = self.mults[0].multiply(int(r)) - self.mult.init( + self.mults[1].init( self.params, R, max(self.bits, self.params.order.bit_length()) ) kr_inv = scalar * mod(int(r), self.params.order).inverse() - return action.exit(self.mult.multiply(int(kr_inv))) + return action.exit(self.mults[1].multiply(int(kr_inv))) @public @@ -227,12 +261,14 @@ class EuclideanSplitting(ScalarMultiplierCountermeasure): &\textbf{return}\ [k_1]G + [k_2]S """ - + nmults = 3 add: Optional[AdditionFormula] def __init__( self, - mult: "ScalarMultiplier | ScalarMultiplierCountermeasure", + mult1: "ScalarMultiplier | ScalarMultiplierCountermeasure", + mult2: "ScalarMultiplier | ScalarMultiplierCountermeasure", + mult3: "ScalarMultiplier | ScalarMultiplierCountermeasure", rng: Callable[[int], Mod] = Mod.random, add: Optional[AdditionFormula] = None, ): @@ -240,14 +276,17 @@ class EuclideanSplitting(ScalarMultiplierCountermeasure): :param mult: The multiplier to use. :param add: Addition formula to use, if None, the formula from the multiplier is used. """ - super().__init__(mult, rng) + super().__init__(mult1, mult2, mult3, rng=rng) self.add = add def _add(self, R: Point, S: Point) -> Point: # noqa if self.add is None: - try: - return self.mult._add(R, S) # type: ignore - except AttributeError: + for mult in self.mults: + try: + return mult._add(R, S) # type: ignore + except AttributeError: + pass + else: raise ValueError("No addition formula available.") else: return self.add( @@ -260,15 +299,17 @@ class EuclideanSplitting(ScalarMultiplierCountermeasure): with ScalarMultiplicationAction(self.point, self.params, scalar) as action: half_bits = self.bits // 2 r = self.rng(1 << half_bits) - self.mult.init(self.params, self.point, half_bits) - R = self.mult.multiply(int(r)) # r bounded by half_bits + self.mults[0].init(self.params, self.point, half_bits) + R = self.mults[0].multiply(int(r)) # r bounded by half_bits + if self.mults[0] != self.mults[1]: + self.mults[1].init(self.params, self.point, half_bits) k1 = scalar % int(r) k2 = scalar // int(r) - T = self.mult.multiply(k1) # k1 bounded by half_bits + T = self.mults[1].multiply(k1) # k1 bounded by half_bits - self.mult.init(self.params, R, self.bits) - S = self.mult.multiply( + self.mults[2].init(self.params, R, self.bits) + S = self.mults[2].multiply( k2 ) # k2 (in worst case) bounded by bits, but in practice closer to half_bits @@ -294,13 +335,24 @@ class BrumleyTuveri(ScalarMultiplierCountermeasure): &\textbf{return}\ [\hat{k}]G """ + nmults = 1 + + def __init__( + self, + mult: "ScalarMultiplier | ScalarMultiplierCountermeasure", + rng: Callable[[int], Mod] = Mod.random, + ): + """ + :param mult: The multiplier to use. + """ + super().__init__(mult, rng=rng) def multiply(self, scalar: int) -> Point: if self.params is None or self.point is None or self.bits is None: raise ValueError("Not initialized.") with ScalarMultiplicationAction(self.point, self.params, scalar) as action: n = self.params.order - self.mult.init( + self.mults[0].init( self.params, self.point, bits=max(self.bits, n.bit_length()) + 1, @@ -308,4 +360,4 @@ class BrumleyTuveri(ScalarMultiplierCountermeasure): scalar += n if scalar.bit_length() <= n.bit_length(): scalar += n - return action.exit(self.mult.multiply(scalar)) + return action.exit(self.mults[0].multiply(scalar)) diff --git a/test/ec/test_countermeasures.py b/test/ec/test_countermeasures.py index 62db0d4..946261c 100644 --- a/test/ec/test_countermeasures.py +++ b/test/ec/test_countermeasures.py @@ -183,7 +183,7 @@ def test_additive_splitting(mults, secp128r1, num): raw = mult.multiply(num) for mult in mults: - asplit = AdditiveSplitting(mult) + asplit = AdditiveSplitting(mult, mult) asplit.init(secp128r1, secp128r1.generator) masked = asplit.multiply(num) assert raw.equals(masked) @@ -202,7 +202,7 @@ def test_multiplicative_splitting(mults, secp128r1, num): raw = mult.multiply(num) for mult in mults: - msplit = MultiplicativeSplitting(mult) + msplit = MultiplicativeSplitting(mult, mult) msplit.init(secp128r1, secp128r1.generator) masked = msplit.multiply(num) assert raw.equals(masked) @@ -221,7 +221,7 @@ def test_euclidean_splitting(mults, secp128r1, num): raw = mult.multiply(num) for mult in mults: - esplit = EuclideanSplitting(mult) + esplit = EuclideanSplitting(mult, mult, mult) esplit.init(secp128r1, secp128r1.generator) masked = esplit.multiply(num) assert raw.equals(masked) @@ -274,6 +274,7 @@ def test_combination(scalar, one, two, secp128r1): mult = LTRMultiplier( secp128r1.curve.coordinate_model.formulas["add-2015-rcb"], secp128r1.curve.coordinate_model.formulas["dbl-2015-rcb"], + scl=secp128r1.curve.coordinate_model.formulas["z"], ) mult.init(secp128r1, secp128r1.generator) raw = mult.multiply(scalar) @@ -281,17 +282,22 @@ def test_combination(scalar, one, two, secp128r1): add = mult.formulas["add"] if one in (AdditiveSplitting, EuclideanSplitting): - layer_one = one(mult, add=add) + layer_one = one.from_single(mult, add=add) else: - layer_one = one(mult) + layer_one = one.from_single(mult) if two in (AdditiveSplitting, EuclideanSplitting): - combo = two(layer_one, add=add) + kws = {"add": add} else: - combo = two(layer_one) - combo.init(secp128r1, secp128r1.generator) - masked = combo.multiply(scalar) - assert raw.equals(masked) + kws = {} + + for i in range(2**two.nmults): + bits = format(i, f"0{two.nmults}b") + args = [layer_one if bit == "1" else mult for bit in bits] + combo = two(*args, **kws) + combo.init(secp128r1, secp128r1.generator) + masked = combo.multiply(scalar) + assert raw.equals(masked) @pytest.mark.parametrize( @@ -324,7 +330,7 @@ def test_rng(scalar, ctr, secp128r1): def rng(n): return mod(123456789, n) - m = ctr(mult, rng) + m = ctr.from_single(mult, rng=rng) m.init(secp128r1, secp128r1.generator) masked = m.multiply(scalar) assert raw.equals(masked) |
