aboutsummaryrefslogtreecommitdiffhomepage
path: root/pyecsca/ec/mult.py
diff options
context:
space:
mode:
Diffstat (limited to 'pyecsca/ec/mult.py')
-rw-r--r--pyecsca/ec/mult.py72
1 files changed, 33 insertions, 39 deletions
diff --git a/pyecsca/ec/mult.py b/pyecsca/ec/mult.py
index 9870296..a623fd0 100644
--- a/pyecsca/ec/mult.py
+++ b/pyecsca/ec/mult.py
@@ -12,7 +12,7 @@ class ScalarMultiplier(object):
curve: EllipticCurve
formulas: Mapping[str, Formula]
context: Context
- _point: Optional[Point] = None
+ _point: Point = None
def __init__(self, curve: EllipticCurve, ctx: Context = None, **formulas: Optional[Formula]):
for formula in formulas.values():
@@ -53,13 +53,22 @@ class ScalarMultiplier(object):
**self.curve.parameters)
def _neg(self, point: Point) -> Point:
- #TODO
+ # TODO
raise NotImplementedError
def init(self, point: Point):
- raise NotImplementedError
+ self._point = point
+
+ def _init_multiply(self, point: Optional[Point]) -> Point:
+ if point is None:
+ if self._point is None:
+ raise ValueError
+ else:
+ if self._point != point:
+ self.init(point)
+ return self._point
- def multiply(self, scalar: int, point: Optional[Point]) -> Point:
+ def multiply(self, scalar: int, point: Optional[Point] = None) -> Point:
raise NotImplementedError
@@ -72,19 +81,15 @@ class LTRMultiplier(ScalarMultiplier):
super().__init__(curve, ctx, add=add, dbl=dbl, scl=scl)
self.always = always
- def init(self, point: Point):
- self._point = point
-
- def multiply(self, scalar: int, point: Optional[Point]) -> Point:
- if point is not None and self._point != point:
- self.init(point)
+ def multiply(self, scalar: int, point: Optional[Point] = None) -> Point:
+ q = self._init_multiply(point)
r = copy(self.curve.neutral)
for i in range(scalar.bit_length(), -1, -1):
r = self._dbl(r)
if scalar & (1 << i) != 0:
- r = self._add(r, self._point)
+ r = self._add(r, q)
elif self.always:
- self._add(r, self._point)
+ self._add(r, q)
if "scl" in self.formulas:
r = self._scl(r)
return r
@@ -99,20 +104,15 @@ class RTLMultiplier(ScalarMultiplier):
super().__init__(curve, ctx, add=add, dbl=dbl, scl=scl)
self.always = always
- def init(self, point: Point):
- self._point = point
-
- def multiply(self, scalar: int, point: Optional[Point]) -> Point:
- if point is not None and self._point != point:
- self.init(point)
+ def multiply(self, scalar: int, point: Optional[Point] = None) -> Point:
+ q = self._init_multiply(point)
r = copy(self.curve.neutral)
- point = self._point
while scalar > 0:
if scalar & 1 != 0:
- r = self._add(r, point)
+ r = self._add(r, q)
elif self.always:
- self._add(r, point)
- point = self._dbl(point)
+ self._add(r, q)
+ q = self._dbl(q)
scalar >>= 1
if "scl" in self.formulas:
r = self._scl(r)
@@ -125,19 +125,15 @@ class LadderMultiplier(ScalarMultiplier):
ctx: Context = None):
super().__init__(curve, ctx, ladd=ladd, scl=scl)
- def init(self, point: Point):
- self._point = point
-
- def multiply(self, scalar: int, point: Optional[Point]) -> Point:
- if point is not None and self._point != point:
- self.init(point)
- p0 = copy(self._point)
- p1 = self._ladd(self.curve.neutral, self._point, self._point)[1]
+ def multiply(self, scalar: int, point: Optional[Point] = None) -> Point:
+ q = self._init_multiply(point)
+ p0 = copy(q)
+ p1 = self._ladd(self.curve.neutral, q, q)[1]
for i in range(scalar.bit_length(), -1, -1):
if scalar & i != 0:
- p0, p1 = self._ladd(self._point, p1, p0)
+ p0, p1 = self._ladd(q, p1, p0)
else:
- p0, p1 = self._ladd(self._point, p0, p1)
+ p0, p1 = self._ladd(q, p0, p1)
if "scl" in self.formulas:
p0 = self._scl(p0)
return p0
@@ -152,12 +148,11 @@ class BinaryNAFMultiplier(ScalarMultiplier):
super().__init__(curve, ctx, add=add, dbl=dbl, scl=scl)
def init(self, point: Point):
- self._point = point
+ super().init(point)
self._point_neg = self._neg(point)
- def multiply(self, scalar: int, point: Optional[Point]) -> Point:
- if point is not None and self._point != point:
- self.init(point)
+ def multiply(self, scalar: int, point: Optional[Point] = None) -> Point:
+ self._init_multiply(point)
bnaf = naf(scalar)
q = copy(self.curve.neutral)
for val in bnaf:
@@ -183,9 +178,8 @@ class WindowNAFMultiplier(ScalarMultiplier):
self._point = point
# TODO: precompute {1, 3, 5, upto 2^(w-1)-1}
- def multiply(self, scalar: int, point: Optional[Point]):
- if point is not None and self._point != point:
- self.init(point)
+ def multiply(self, scalar: int, point: Optional[Point] = None):
+ self._init_multiply(point)
naf = wnaf(scalar, self._width)
q = copy(self.curve.neutral)
for val in naf: