aboutsummaryrefslogtreecommitdiff
path: root/pyecsca/ec/countermeasures.py
diff options
context:
space:
mode:
Diffstat (limited to 'pyecsca/ec/countermeasures.py')
-rw-r--r--pyecsca/ec/countermeasures.py128
1 files changed, 90 insertions, 38 deletions
diff --git a/pyecsca/ec/countermeasures.py b/pyecsca/ec/countermeasures.py
index 86f7177..853260c 100644
--- a/pyecsca/ec/countermeasures.py
+++ b/pyecsca/ec/countermeasures.py
@@ -1,7 +1,7 @@
"""Provides several countermeasures against side-channel attacks."""
from abc import ABC, abstractmethod
-from typing import Optional, Callable
+from typing import Optional, Callable, get_type_hints, ClassVar
from public import public
@@ -21,8 +21,10 @@ class ScalarMultiplierCountermeasure(ABC):
and provides some scalar-splitting countermeasure.
"""
- mult: "ScalarMultiplier | ScalarMultiplierCountermeasure"
- """The underlying scalar multiplier (or another countermeasure)."""
+ mults: list["ScalarMultiplier | ScalarMultiplierCountermeasure"]
+ """The underlying scalar multipliers (or another countermeasure)."""
+ nmults: ClassVar[int]
+ """The number of scalar multipliers required."""
params: Optional[DomainParameters]
"""The domain parameters, if any."""
point: Optional[Point]
@@ -32,10 +34,14 @@ class ScalarMultiplierCountermeasure(ABC):
def __init__(
self,
- mult: "ScalarMultiplier | ScalarMultiplierCountermeasure",
+ *mults: "ScalarMultiplier | ScalarMultiplierCountermeasure",
rng: Callable[[int], Mod] = Mod.random,
):
- self.mult = mult
+ self.mults = list(mults)
+ if len(self.mults) != self.nmults:
+ raise ValueError(
+ f"Expected {self.nmults} multipliers, got {len(self.mults)}."
+ )
self.rng = rng
def init(self, params: DomainParameters, point: Point, bits: Optional[int] = None):
@@ -60,6 +66,24 @@ class ScalarMultiplierCountermeasure(ABC):
"""
raise NotImplementedError
+ @classmethod
+ def from_single(
+ cls, mult: "ScalarMultiplier | ScalarMultiplierCountermeasure", **kwargs
+ ):
+ """
+ Create an instance of the countermeasure from a single scalar multiplier.
+
+ :param mult: The scalar multiplier to use.
+ :return: An instance of the countermeasure.
+ """
+ th = get_type_hints(cls.__init__)
+ num = 0
+ for name, arg_type in th.items():
+ if name.startswith("mult"):
+ num += 1
+ mults = [mult] * num
+ return cls(*mults, **kwargs)
+
@public
class GroupScalarRandomization(ScalarMultiplierCountermeasure):
@@ -75,7 +99,7 @@ class GroupScalarRandomization(ScalarMultiplierCountermeasure):
&\textbf{return}\ [k + r n]G
"""
-
+ nmults = 1
rand_bits: int
def __init__(
@@ -88,7 +112,7 @@ class GroupScalarRandomization(ScalarMultiplierCountermeasure):
:param mult: The multiplier to use.
:param rand_bits: How many random bits to sample.
"""
- super().__init__(mult, rng)
+ super().__init__(mult, rng=rng)
self.rand_bits = rand_bits
def multiply(self, scalar: int) -> Point:
@@ -99,12 +123,12 @@ class GroupScalarRandomization(ScalarMultiplierCountermeasure):
mask = int(self.rng(1 << self.rand_bits))
masked_scalar = scalar + mask * order
bits = max(self.bits, self.rand_bits + order.bit_length()) + 1
- self.mult.init(
+ self.mults[0].init(
self.params,
self.point,
bits=bits,
)
- return action.exit(self.mult.multiply(masked_scalar))
+ return action.exit(self.mults[0].multiply(masked_scalar))
@public
@@ -121,12 +145,13 @@ class AdditiveSplitting(ScalarMultiplierCountermeasure):
&\textbf{return}\ [k - r]G + [r]G
"""
-
+ nmults = 2
add: Optional[AdditionFormula]
def __init__(
self,
- mult: "ScalarMultiplier | ScalarMultiplierCountermeasure",
+ mult1: "ScalarMultiplier | ScalarMultiplierCountermeasure",
+ mult2: "ScalarMultiplier | ScalarMultiplierCountermeasure",
rng: Callable[[int], Mod] = Mod.random,
add: Optional[AdditionFormula] = None,
):
@@ -134,14 +159,17 @@ class AdditiveSplitting(ScalarMultiplierCountermeasure):
:param mult: The multiplier to use.
:param add: Addition formula to use, if None, the formula from the multiplier is used.
"""
- super().__init__(mult, rng)
+ super().__init__(mult1, mult2, rng=rng)
self.add = add
def _add(self, R: Point, S: Point) -> Point: # noqa
if self.add is None:
- try:
- return self.mult._add(R, S) # type: ignore
- except AttributeError:
+ for mult in self.mults:
+ try:
+ return mult._add(R, S) # type: ignore
+ except AttributeError:
+ pass
+ else:
raise ValueError("No addition formula available.")
else:
return self.add(
@@ -156,9 +184,14 @@ class AdditiveSplitting(ScalarMultiplierCountermeasure):
r = self.rng(order)
s = scalar - r
bits = max(self.bits, order.bit_length())
- self.mult.init(self.params, self.point, bits)
- R = self.mult.multiply(int(r))
- S = self.mult.multiply(int(s))
+
+ self.mults[0].init(self.params, self.point, bits)
+ R = self.mults[0].multiply(int(r))
+
+ if self.mults[0] != self.mults[1]:
+ self.mults[1].init(self.params, self.point, bits)
+ S = self.mults[1].multiply(int(s))
+
res = self._add(R, S)
return action.exit(res)
@@ -178,12 +211,13 @@ class MultiplicativeSplitting(ScalarMultiplierCountermeasure):
&\textbf{return}\ [k r^{-1} \mod n]S
"""
-
+ nmults = 2
rand_bits: int
def __init__(
self,
- mult: "ScalarMultiplier | ScalarMultiplierCountermeasure",
+ mult1: "ScalarMultiplier | ScalarMultiplierCountermeasure",
+ mult2: "ScalarMultiplier | ScalarMultiplierCountermeasure",
rng: Callable[[int], Mod] = Mod.random,
rand_bits: int = 32,
):
@@ -191,7 +225,7 @@ class MultiplicativeSplitting(ScalarMultiplierCountermeasure):
:param mult: The multiplier to use.
:param rand_bits: How many random bits to sample.
"""
- super().__init__(mult, rng)
+ super().__init__(mult1, mult2, rng=rng)
self.rand_bits = rand_bits
def multiply(self, scalar: int) -> Point:
@@ -199,14 +233,14 @@ class MultiplicativeSplitting(ScalarMultiplierCountermeasure):
raise ValueError("Not initialized.")
with ScalarMultiplicationAction(self.point, self.params, scalar) as action:
r = self.rng(1 << self.rand_bits)
- self.mult.init(self.params, self.point, self.rand_bits)
- R = self.mult.multiply(int(r))
+ self.mults[0].init(self.params, self.point, self.rand_bits)
+ R = self.mults[0].multiply(int(r))
- self.mult.init(
+ self.mults[1].init(
self.params, R, max(self.bits, self.params.order.bit_length())
)
kr_inv = scalar * mod(int(r), self.params.order).inverse()
- return action.exit(self.mult.multiply(int(kr_inv)))
+ return action.exit(self.mults[1].multiply(int(kr_inv)))
@public
@@ -227,12 +261,14 @@ class EuclideanSplitting(ScalarMultiplierCountermeasure):
&\textbf{return}\ [k_1]G + [k_2]S
"""
-
+ nmults = 3
add: Optional[AdditionFormula]
def __init__(
self,
- mult: "ScalarMultiplier | ScalarMultiplierCountermeasure",
+ mult1: "ScalarMultiplier | ScalarMultiplierCountermeasure",
+ mult2: "ScalarMultiplier | ScalarMultiplierCountermeasure",
+ mult3: "ScalarMultiplier | ScalarMultiplierCountermeasure",
rng: Callable[[int], Mod] = Mod.random,
add: Optional[AdditionFormula] = None,
):
@@ -240,14 +276,17 @@ class EuclideanSplitting(ScalarMultiplierCountermeasure):
:param mult: The multiplier to use.
:param add: Addition formula to use, if None, the formula from the multiplier is used.
"""
- super().__init__(mult, rng)
+ super().__init__(mult1, mult2, mult3, rng=rng)
self.add = add
def _add(self, R: Point, S: Point) -> Point: # noqa
if self.add is None:
- try:
- return self.mult._add(R, S) # type: ignore
- except AttributeError:
+ for mult in self.mults:
+ try:
+ return mult._add(R, S) # type: ignore
+ except AttributeError:
+ pass
+ else:
raise ValueError("No addition formula available.")
else:
return self.add(
@@ -260,15 +299,17 @@ class EuclideanSplitting(ScalarMultiplierCountermeasure):
with ScalarMultiplicationAction(self.point, self.params, scalar) as action:
half_bits = self.bits // 2
r = self.rng(1 << half_bits)
- self.mult.init(self.params, self.point, half_bits)
- R = self.mult.multiply(int(r)) # r bounded by half_bits
+ self.mults[0].init(self.params, self.point, half_bits)
+ R = self.mults[0].multiply(int(r)) # r bounded by half_bits
+ if self.mults[0] != self.mults[1]:
+ self.mults[1].init(self.params, self.point, half_bits)
k1 = scalar % int(r)
k2 = scalar // int(r)
- T = self.mult.multiply(k1) # k1 bounded by half_bits
+ T = self.mults[1].multiply(k1) # k1 bounded by half_bits
- self.mult.init(self.params, R, self.bits)
- S = self.mult.multiply(
+ self.mults[2].init(self.params, R, self.bits)
+ S = self.mults[2].multiply(
k2
) # k2 (in worst case) bounded by bits, but in practice closer to half_bits
@@ -294,13 +335,24 @@ class BrumleyTuveri(ScalarMultiplierCountermeasure):
&\textbf{return}\ [\hat{k}]G
"""
+ nmults = 1
+
+ def __init__(
+ self,
+ mult: "ScalarMultiplier | ScalarMultiplierCountermeasure",
+ rng: Callable[[int], Mod] = Mod.random,
+ ):
+ """
+ :param mult: The multiplier to use.
+ """
+ super().__init__(mult, rng=rng)
def multiply(self, scalar: int) -> Point:
if self.params is None or self.point is None or self.bits is None:
raise ValueError("Not initialized.")
with ScalarMultiplicationAction(self.point, self.params, scalar) as action:
n = self.params.order
- self.mult.init(
+ self.mults[0].init(
self.params,
self.point,
bits=max(self.bits, n.bit_length()) + 1,
@@ -308,4 +360,4 @@ class BrumleyTuveri(ScalarMultiplierCountermeasure):
scalar += n
if scalar.bit_length() <= n.bit_length():
scalar += n
- return action.exit(self.mult.multiply(scalar))
+ return action.exit(self.mults[0].multiply(scalar))