aboutsummaryrefslogtreecommitdiff
path: root/pyecsca/ec/mult.py
diff options
context:
space:
mode:
Diffstat (limited to 'pyecsca/ec/mult.py')
-rw-r--r--pyecsca/ec/mult.py301
1 files changed, 160 insertions, 141 deletions
diff --git a/pyecsca/ec/mult.py b/pyecsca/ec/mult.py
index 4f937fa..4c5b6d9 100644
--- a/pyecsca/ec/mult.py
+++ b/pyecsca/ec/mult.py
@@ -1,16 +1,31 @@
from copy import copy
-from typing import Mapping, Tuple, Optional, MutableMapping, Union, ClassVar, Set, Type
+from typing import Mapping, Tuple, Optional, MutableMapping, ClassVar, Set, Type
from public import public
-from .context import getcontext
+from .context import Action
from .formula import (Formula, AdditionFormula, DoublingFormula, DifferentialAdditionFormula,
ScalingFormula, LadderFormula, NegationFormula)
-from .params import DomainParameters
from .naf import naf, wnaf
+from .params import DomainParameters
from .point import Point
+@public
+class ScalarMultiplicationAction(Action):
+ """A scalar multiplication of a point on a curve by a scalar."""
+ point: Point
+ scalar: int
+
+ def __init__(self, point: Point, scalar: int):
+ super().__init__()
+ self.point = point
+ self.scalar = scalar
+
+ def __repr__(self):
+ return f"{self.__class__.__name__}({self.point}, {self.scalar})"
+
+
class ScalarMultiplier(object):
"""
A scalar multiplication algorithm.
@@ -42,9 +57,7 @@ class ScalarMultiplier(object):
return copy(other)
if other == self._group.neutral:
return copy(one)
- return \
- getcontext().execute(self.formulas["add"], one, other, **self._group.curve.parameters)[
- 0]
+ return self.formulas["add"](one, other, **self._group.curve.parameters)[0]
def _dbl(self, point: Point) -> Point:
if "dbl" not in self.formulas:
@@ -52,12 +65,12 @@ class ScalarMultiplier(object):
if self.short_circuit:
if point == self._group.neutral:
return copy(point)
- return getcontext().execute(self.formulas["dbl"], point, **self._group.curve.parameters)[0]
+ return self.formulas["dbl"](point, **self._group.curve.parameters)[0]
def _scl(self, point: Point) -> Point:
if "scl" not in self.formulas:
raise NotImplementedError
- return getcontext().execute(self.formulas["scl"], point, **self._group.curve.parameters)[0]
+ return self.formulas["scl"](point, **self._group.curve.parameters)[0]
def _ladd(self, start: Point, to_dbl: Point, to_add: Point) -> Tuple[Point, ...]:
if "ladd" not in self.formulas:
@@ -67,8 +80,7 @@ class ScalarMultiplier(object):
return to_dbl, to_add
if to_add == self._group.neutral:
return self._dbl(to_dbl), to_dbl
- return getcontext().execute(self.formulas["ladd"], start, to_dbl, to_add,
- **self._group.curve.parameters)
+ return self.formulas["ladd"](start, to_dbl, to_add, **self._group.curve.parameters)
def _dadd(self, start: Point, one: Point, other: Point) -> Point:
if "dadd" not in self.formulas:
@@ -78,13 +90,12 @@ class ScalarMultiplier(object):
return copy(other)
if other == self._group.neutral:
return copy(one)
- return getcontext().execute(self.formulas["dadd"], start, one, other,
- **self._group.curve.parameters)[0]
+ return self.formulas["dadd"](start, one, other, **self._group.curve.parameters)[0]
def _neg(self, point: Point) -> Point:
if "neg" not in self.formulas:
raise NotImplementedError
- return getcontext().execute(self.formulas["neg"], point, **self._group.curve.parameters)[0]
+ return self.formulas["neg"](point, **self._group.curve.parameters)[0]
def init(self, group: DomainParameters, point: Point):
"""Initialize the scalar multiplier with a group and a point."""
@@ -122,25 +133,26 @@ class LTRMultiplier(ScalarMultiplier):
def multiply(self, scalar: int) -> Point:
if not self._initialized:
raise ValueError("ScalaMultiplier not initialized.")
- if scalar == 0:
- return copy(self._group.neutral)
- if self.complete:
- q = self._point
- r = copy(self._group.neutral)
- top = self._group.order.bit_length() - 1
- else:
- q = self._dbl(self._point)
- r = copy(self._point)
- top = scalar.bit_length() - 2
- for i in range(top, -1, -1):
- r = self._dbl(r)
- if scalar & (1 << i) != 0:
- r = self._add(r, q)
- elif self.always:
- self._add(r, q)
- if "scl" in self.formulas:
- r = self._scl(r)
- return r
+ with ScalarMultiplicationAction(self._point, scalar):
+ if scalar == 0:
+ return copy(self._group.neutral)
+ if self.complete:
+ q = self._point
+ r = copy(self._group.neutral)
+ top = self._group.order.bit_length() - 1
+ else:
+ q = self._dbl(self._point)
+ r = copy(self._point)
+ top = scalar.bit_length() - 2
+ for i in range(top, -1, -1):
+ r = self._dbl(r)
+ if scalar & (1 << i) != 0:
+ r = self._add(r, q)
+ elif self.always:
+ self._add(r, q)
+ if "scl" in self.formulas:
+ r = self._scl(r)
+ return r
@public
@@ -162,20 +174,21 @@ class RTLMultiplier(ScalarMultiplier):
def multiply(self, scalar: int) -> Point:
if not self._initialized:
raise ValueError("ScalaMultiplier not initialized.")
- if scalar == 0:
- return copy(self._group.neutral)
- q = self._point
- r = copy(self._group.neutral)
- while scalar > 0:
- if scalar & 1 != 0:
- r = self._add(r, q)
- elif self.always:
- self._add(r, q)
- q = self._dbl(q)
- scalar >>= 1
- if "scl" in self.formulas:
- r = self._scl(r)
- return r
+ with ScalarMultiplicationAction(self._point, scalar):
+ if scalar == 0:
+ return copy(self._group.neutral)
+ q = self._point
+ r = copy(self._group.neutral)
+ while scalar > 0:
+ if scalar & 1 != 0:
+ r = self._add(r, q)
+ elif self.always:
+ self._add(r, q)
+ q = self._dbl(q)
+ scalar >>= 1
+ if "scl" in self.formulas:
+ r = self._scl(r)
+ return r
class CoronMultiplier(ScalarMultiplier):
@@ -196,18 +209,19 @@ class CoronMultiplier(ScalarMultiplier):
def multiply(self, scalar: int) -> Point:
if not self._initialized:
raise ValueError("ScalaMultiplier not initialized.")
- if scalar == 0:
- return copy(self._group.neutral)
- q = self._point
- p0 = copy(q)
- for i in range(scalar.bit_length() - 2, -1, -1):
- p0 = self._dbl(p0)
- p1 = self._add(p0, q)
- if scalar & (1 << i) != 0:
- p0 = p1
- if "scl" in self.formulas:
- p0 = self._scl(p0)
- return p0
+ with ScalarMultiplicationAction(self._point, scalar):
+ if scalar == 0:
+ return copy(self._group.neutral)
+ q = self._point
+ p0 = copy(q)
+ for i in range(scalar.bit_length() - 2, -1, -1):
+ p0 = self._dbl(p0)
+ p1 = self._add(p0, q)
+ if scalar & (1 << i) != 0:
+ p0 = p1
+ if "scl" in self.formulas:
+ p0 = self._scl(p0)
+ return p0
@public
@@ -229,25 +243,26 @@ class LadderMultiplier(ScalarMultiplier):
def multiply(self, scalar: int) -> Point:
if not self._initialized:
raise ValueError("ScalaMultiplier not initialized.")
- if scalar == 0:
- return copy(self._group.neutral)
- q = self._point
- if self.complete:
- p0 = copy(self._group.neutral)
- p1 = self._point
- top = self._group.order.bit_length() - 1
- else:
- p0 = copy(q)
- p1 = self._dbl(q)
- top = scalar.bit_length() - 2
- for i in range(top, -1, -1):
- if scalar & (1 << i) == 0:
- p0, p1 = self._ladd(q, p0, p1)
+ with ScalarMultiplicationAction(self._point, scalar):
+ if scalar == 0:
+ return copy(self._group.neutral)
+ q = self._point
+ if self.complete:
+ p0 = copy(self._group.neutral)
+ p1 = self._point
+ top = self._group.order.bit_length() - 1
else:
- p1, p0 = self._ladd(q, p1, p0)
- if "scl" in self.formulas:
- p0 = self._scl(p0)
- return p0
+ p0 = copy(q)
+ p1 = self._dbl(q)
+ top = scalar.bit_length() - 2
+ for i in range(top, -1, -1):
+ if scalar & (1 << i) == 0:
+ p0, p1 = self._ladd(q, p0, p1)
+ else:
+ p1, p0 = self._ladd(q, p1, p0)
+ if "scl" in self.formulas:
+ p0 = self._scl(p0)
+ return p0
@public
@@ -267,24 +282,25 @@ class SimpleLadderMultiplier(ScalarMultiplier):
def multiply(self, scalar: int) -> Point:
if not self._initialized:
raise ValueError("ScalaMultiplier not initialized.")
- if scalar == 0:
- return copy(self._group.neutral)
- if self.complete:
- top = self._group.order.bit_length() - 1
- else:
- top = scalar.bit_length() - 1
- p0 = copy(self._group.neutral)
- p1 = copy(self._point)
- for i in range(top, -1, -1):
- if scalar & (1 << i) == 0:
- p1 = self._add(p0, p1)
- p0 = self._dbl(p0)
+ with ScalarMultiplicationAction(self._point, scalar):
+ if scalar == 0:
+ return copy(self._group.neutral)
+ if self.complete:
+ top = self._group.order.bit_length() - 1
else:
- p0 = self._add(p0, p1)
- p1 = self._dbl(p1)
- if "scl" in self.formulas:
- p0 = self._scl(p0)
- return p0
+ top = scalar.bit_length() - 1
+ p0 = copy(self._group.neutral)
+ p1 = copy(self._point)
+ for i in range(top, -1, -1):
+ if scalar & (1 << i) == 0:
+ p1 = self._add(p0, p1)
+ p0 = self._dbl(p0)
+ else:
+ p0 = self._add(p0, p1)
+ p1 = self._dbl(p1)
+ if "scl" in self.formulas:
+ p0 = self._scl(p0)
+ return p0
@public
@@ -304,25 +320,26 @@ class DifferentialLadderMultiplier(ScalarMultiplier):
def multiply(self, scalar: int) -> Point:
if not self._initialized:
raise ValueError("ScalaMultiplier not initialized.")
- if scalar == 0:
- return copy(self._group.neutral)
- if self.complete:
- top = self._group.order.bit_length() - 1
- else:
- top = scalar.bit_length() - 1
- q = self._point
- p0 = copy(self._group.neutral)
- p1 = copy(q)
- for i in range(top, -1, -1):
- if scalar & (1 << i) == 0:
- p1 = self._dadd(q, p0, p1)
- p0 = self._dbl(p0)
+ with ScalarMultiplicationAction(self._point, scalar):
+ if scalar == 0:
+ return copy(self._group.neutral)
+ if self.complete:
+ top = self._group.order.bit_length() - 1
else:
- p0 = self._dadd(q, p0, p1)
- p1 = self._dbl(p1)
- if "scl" in self.formulas:
- p0 = self._scl(p0)
- return p0
+ top = scalar.bit_length() - 1
+ q = self._point
+ p0 = copy(self._group.neutral)
+ p1 = copy(q)
+ for i in range(top, -1, -1):
+ if scalar & (1 << i) == 0:
+ p1 = self._dadd(q, p0, p1)
+ p0 = self._dbl(p0)
+ else:
+ p0 = self._dadd(q, p0, p1)
+ p1 = self._dbl(p1)
+ if "scl" in self.formulas:
+ p0 = self._scl(p0)
+ return p0
@public
@@ -343,19 +360,20 @@ class BinaryNAFMultiplier(ScalarMultiplier):
def multiply(self, scalar: int) -> Point:
if not self._initialized:
raise ValueError("ScalaMultiplier not initialized.")
- if scalar == 0:
- return copy(self._group.neutral)
- bnaf = naf(scalar)
- q = copy(self._group.neutral)
- for val in bnaf:
- q = self._dbl(q)
- if val == 1:
- q = self._add(q, self._point)
- if val == -1:
- q = self._add(q, self._point_neg)
- if "scl" in self.formulas:
- q = self._scl(q)
- return q
+ with ScalarMultiplicationAction(self._point, scalar):
+ if scalar == 0:
+ return copy(self._group.neutral)
+ bnaf = naf(scalar)
+ q = copy(self._group.neutral)
+ for val in bnaf:
+ q = self._dbl(q)
+ if val == 1:
+ q = self._add(q, self._point)
+ if val == -1:
+ q = self._add(q, self._point_neg)
+ if "scl" in self.formulas:
+ q = self._scl(q)
+ return q
@public
@@ -390,20 +408,21 @@ class WindowNAFMultiplier(ScalarMultiplier):
def multiply(self, scalar: int) -> Point:
if not self._initialized:
raise ValueError("ScalaMultiplier not initialized.")
- if scalar == 0:
- return copy(self._group.neutral)
- naf = wnaf(scalar, self.width)
- q = copy(self._group.neutral)
- for val in naf:
- q = self._dbl(q)
- if val > 0:
- q = self._add(q, self._points[val])
- elif val < 0:
- if self.precompute_negation:
- neg = self._points_neg[-val]
- else:
- neg = self._neg(self._points[-val])
- q = self._add(q, neg)
- if "scl" in self.formulas:
- q = self._scl(q)
- return q
+ with ScalarMultiplicationAction(self._point, scalar):
+ if scalar == 0:
+ return copy(self._group.neutral)
+ naf = wnaf(scalar, self.width)
+ q = copy(self._group.neutral)
+ for val in naf:
+ q = self._dbl(q)
+ if val > 0:
+ q = self._add(q, self._points[val])
+ elif val < 0:
+ if self.precompute_negation:
+ neg = self._points_neg[-val]
+ else:
+ neg = self._neg(self._points[-val])
+ q = self._add(q, neg)
+ if "scl" in self.formulas:
+ q = self._scl(q)
+ return q