diff options
Diffstat (limited to 'pyecsca/ec/mult.py')
| -rw-r--r-- | pyecsca/ec/mult.py | 72 |
1 files changed, 33 insertions, 39 deletions
diff --git a/pyecsca/ec/mult.py b/pyecsca/ec/mult.py index 9870296..a623fd0 100644 --- a/pyecsca/ec/mult.py +++ b/pyecsca/ec/mult.py @@ -12,7 +12,7 @@ class ScalarMultiplier(object): curve: EllipticCurve formulas: Mapping[str, Formula] context: Context - _point: Optional[Point] = None + _point: Point = None def __init__(self, curve: EllipticCurve, ctx: Context = None, **formulas: Optional[Formula]): for formula in formulas.values(): @@ -53,13 +53,22 @@ class ScalarMultiplier(object): **self.curve.parameters) def _neg(self, point: Point) -> Point: - #TODO + # TODO raise NotImplementedError def init(self, point: Point): - raise NotImplementedError + self._point = point + + def _init_multiply(self, point: Optional[Point]) -> Point: + if point is None: + if self._point is None: + raise ValueError + else: + if self._point != point: + self.init(point) + return self._point - def multiply(self, scalar: int, point: Optional[Point]) -> Point: + def multiply(self, scalar: int, point: Optional[Point] = None) -> Point: raise NotImplementedError @@ -72,19 +81,15 @@ class LTRMultiplier(ScalarMultiplier): super().__init__(curve, ctx, add=add, dbl=dbl, scl=scl) self.always = always - def init(self, point: Point): - self._point = point - - def multiply(self, scalar: int, point: Optional[Point]) -> Point: - if point is not None and self._point != point: - self.init(point) + def multiply(self, scalar: int, point: Optional[Point] = None) -> Point: + q = self._init_multiply(point) r = copy(self.curve.neutral) for i in range(scalar.bit_length(), -1, -1): r = self._dbl(r) if scalar & (1 << i) != 0: - r = self._add(r, self._point) + r = self._add(r, q) elif self.always: - self._add(r, self._point) + self._add(r, q) if "scl" in self.formulas: r = self._scl(r) return r @@ -99,20 +104,15 @@ class RTLMultiplier(ScalarMultiplier): super().__init__(curve, ctx, add=add, dbl=dbl, scl=scl) self.always = always - def init(self, point: Point): - self._point = point - - def multiply(self, scalar: int, point: Optional[Point]) -> Point: - if point is not None and self._point != point: - self.init(point) + def multiply(self, scalar: int, point: Optional[Point] = None) -> Point: + q = self._init_multiply(point) r = copy(self.curve.neutral) - point = self._point while scalar > 0: if scalar & 1 != 0: - r = self._add(r, point) + r = self._add(r, q) elif self.always: - self._add(r, point) - point = self._dbl(point) + self._add(r, q) + q = self._dbl(q) scalar >>= 1 if "scl" in self.formulas: r = self._scl(r) @@ -125,19 +125,15 @@ class LadderMultiplier(ScalarMultiplier): ctx: Context = None): super().__init__(curve, ctx, ladd=ladd, scl=scl) - def init(self, point: Point): - self._point = point - - def multiply(self, scalar: int, point: Optional[Point]) -> Point: - if point is not None and self._point != point: - self.init(point) - p0 = copy(self._point) - p1 = self._ladd(self.curve.neutral, self._point, self._point)[1] + def multiply(self, scalar: int, point: Optional[Point] = None) -> Point: + q = self._init_multiply(point) + p0 = copy(q) + p1 = self._ladd(self.curve.neutral, q, q)[1] for i in range(scalar.bit_length(), -1, -1): if scalar & i != 0: - p0, p1 = self._ladd(self._point, p1, p0) + p0, p1 = self._ladd(q, p1, p0) else: - p0, p1 = self._ladd(self._point, p0, p1) + p0, p1 = self._ladd(q, p0, p1) if "scl" in self.formulas: p0 = self._scl(p0) return p0 @@ -152,12 +148,11 @@ class BinaryNAFMultiplier(ScalarMultiplier): super().__init__(curve, ctx, add=add, dbl=dbl, scl=scl) def init(self, point: Point): - self._point = point + super().init(point) self._point_neg = self._neg(point) - def multiply(self, scalar: int, point: Optional[Point]) -> Point: - if point is not None and self._point != point: - self.init(point) + def multiply(self, scalar: int, point: Optional[Point] = None) -> Point: + self._init_multiply(point) bnaf = naf(scalar) q = copy(self.curve.neutral) for val in bnaf: @@ -183,9 +178,8 @@ class WindowNAFMultiplier(ScalarMultiplier): self._point = point # TODO: precompute {1, 3, 5, upto 2^(w-1)-1} - def multiply(self, scalar: int, point: Optional[Point]): - if point is not None and self._point != point: - self.init(point) + def multiply(self, scalar: int, point: Optional[Point] = None): + self._init_multiply(point) naf = wnaf(scalar, self._width) q = copy(self.curve.neutral) for val in naf: |
