diff options
| author | J08nY | 2020-02-11 20:44:45 +0100 |
|---|---|---|
| committer | J08nY | 2020-02-11 20:44:45 +0100 |
| commit | 11bd56b296f1620932f098a6037f0807e7f6616f (patch) | |
| tree | 2a791114a710ab49af523cf1ba2144646ef9ad90 /pyecsca/ec/mult.py | |
| parent | 4e2bd346baf2db39391deb49e9bdb9d89f94101a (diff) | |
| download | pyecsca-11bd56b296f1620932f098a6037f0807e7f6616f.tar.gz pyecsca-11bd56b296f1620932f098a6037f0807e7f6616f.tar.zst pyecsca-11bd56b296f1620932f098a6037f0807e7f6616f.zip | |
Diffstat (limited to 'pyecsca/ec/mult.py')
| -rw-r--r-- | pyecsca/ec/mult.py | 301 |
1 files changed, 160 insertions, 141 deletions
diff --git a/pyecsca/ec/mult.py b/pyecsca/ec/mult.py index 4f937fa..4c5b6d9 100644 --- a/pyecsca/ec/mult.py +++ b/pyecsca/ec/mult.py @@ -1,16 +1,31 @@ from copy import copy -from typing import Mapping, Tuple, Optional, MutableMapping, Union, ClassVar, Set, Type +from typing import Mapping, Tuple, Optional, MutableMapping, ClassVar, Set, Type from public import public -from .context import getcontext +from .context import Action from .formula import (Formula, AdditionFormula, DoublingFormula, DifferentialAdditionFormula, ScalingFormula, LadderFormula, NegationFormula) -from .params import DomainParameters from .naf import naf, wnaf +from .params import DomainParameters from .point import Point +@public +class ScalarMultiplicationAction(Action): + """A scalar multiplication of a point on a curve by a scalar.""" + point: Point + scalar: int + + def __init__(self, point: Point, scalar: int): + super().__init__() + self.point = point + self.scalar = scalar + + def __repr__(self): + return f"{self.__class__.__name__}({self.point}, {self.scalar})" + + class ScalarMultiplier(object): """ A scalar multiplication algorithm. @@ -42,9 +57,7 @@ class ScalarMultiplier(object): return copy(other) if other == self._group.neutral: return copy(one) - return \ - getcontext().execute(self.formulas["add"], one, other, **self._group.curve.parameters)[ - 0] + return self.formulas["add"](one, other, **self._group.curve.parameters)[0] def _dbl(self, point: Point) -> Point: if "dbl" not in self.formulas: @@ -52,12 +65,12 @@ class ScalarMultiplier(object): if self.short_circuit: if point == self._group.neutral: return copy(point) - return getcontext().execute(self.formulas["dbl"], point, **self._group.curve.parameters)[0] + return self.formulas["dbl"](point, **self._group.curve.parameters)[0] def _scl(self, point: Point) -> Point: if "scl" not in self.formulas: raise NotImplementedError - return getcontext().execute(self.formulas["scl"], point, **self._group.curve.parameters)[0] + return self.formulas["scl"](point, **self._group.curve.parameters)[0] def _ladd(self, start: Point, to_dbl: Point, to_add: Point) -> Tuple[Point, ...]: if "ladd" not in self.formulas: @@ -67,8 +80,7 @@ class ScalarMultiplier(object): return to_dbl, to_add if to_add == self._group.neutral: return self._dbl(to_dbl), to_dbl - return getcontext().execute(self.formulas["ladd"], start, to_dbl, to_add, - **self._group.curve.parameters) + return self.formulas["ladd"](start, to_dbl, to_add, **self._group.curve.parameters) def _dadd(self, start: Point, one: Point, other: Point) -> Point: if "dadd" not in self.formulas: @@ -78,13 +90,12 @@ class ScalarMultiplier(object): return copy(other) if other == self._group.neutral: return copy(one) - return getcontext().execute(self.formulas["dadd"], start, one, other, - **self._group.curve.parameters)[0] + return self.formulas["dadd"](start, one, other, **self._group.curve.parameters)[0] def _neg(self, point: Point) -> Point: if "neg" not in self.formulas: raise NotImplementedError - return getcontext().execute(self.formulas["neg"], point, **self._group.curve.parameters)[0] + return self.formulas["neg"](point, **self._group.curve.parameters)[0] def init(self, group: DomainParameters, point: Point): """Initialize the scalar multiplier with a group and a point.""" @@ -122,25 +133,26 @@ class LTRMultiplier(ScalarMultiplier): def multiply(self, scalar: int) -> Point: if not self._initialized: raise ValueError("ScalaMultiplier not initialized.") - if scalar == 0: - return copy(self._group.neutral) - if self.complete: - q = self._point - r = copy(self._group.neutral) - top = self._group.order.bit_length() - 1 - else: - q = self._dbl(self._point) - r = copy(self._point) - top = scalar.bit_length() - 2 - for i in range(top, -1, -1): - r = self._dbl(r) - if scalar & (1 << i) != 0: - r = self._add(r, q) - elif self.always: - self._add(r, q) - if "scl" in self.formulas: - r = self._scl(r) - return r + with ScalarMultiplicationAction(self._point, scalar): + if scalar == 0: + return copy(self._group.neutral) + if self.complete: + q = self._point + r = copy(self._group.neutral) + top = self._group.order.bit_length() - 1 + else: + q = self._dbl(self._point) + r = copy(self._point) + top = scalar.bit_length() - 2 + for i in range(top, -1, -1): + r = self._dbl(r) + if scalar & (1 << i) != 0: + r = self._add(r, q) + elif self.always: + self._add(r, q) + if "scl" in self.formulas: + r = self._scl(r) + return r @public @@ -162,20 +174,21 @@ class RTLMultiplier(ScalarMultiplier): def multiply(self, scalar: int) -> Point: if not self._initialized: raise ValueError("ScalaMultiplier not initialized.") - if scalar == 0: - return copy(self._group.neutral) - q = self._point - r = copy(self._group.neutral) - while scalar > 0: - if scalar & 1 != 0: - r = self._add(r, q) - elif self.always: - self._add(r, q) - q = self._dbl(q) - scalar >>= 1 - if "scl" in self.formulas: - r = self._scl(r) - return r + with ScalarMultiplicationAction(self._point, scalar): + if scalar == 0: + return copy(self._group.neutral) + q = self._point + r = copy(self._group.neutral) + while scalar > 0: + if scalar & 1 != 0: + r = self._add(r, q) + elif self.always: + self._add(r, q) + q = self._dbl(q) + scalar >>= 1 + if "scl" in self.formulas: + r = self._scl(r) + return r class CoronMultiplier(ScalarMultiplier): @@ -196,18 +209,19 @@ class CoronMultiplier(ScalarMultiplier): def multiply(self, scalar: int) -> Point: if not self._initialized: raise ValueError("ScalaMultiplier not initialized.") - if scalar == 0: - return copy(self._group.neutral) - q = self._point - p0 = copy(q) - for i in range(scalar.bit_length() - 2, -1, -1): - p0 = self._dbl(p0) - p1 = self._add(p0, q) - if scalar & (1 << i) != 0: - p0 = p1 - if "scl" in self.formulas: - p0 = self._scl(p0) - return p0 + with ScalarMultiplicationAction(self._point, scalar): + if scalar == 0: + return copy(self._group.neutral) + q = self._point + p0 = copy(q) + for i in range(scalar.bit_length() - 2, -1, -1): + p0 = self._dbl(p0) + p1 = self._add(p0, q) + if scalar & (1 << i) != 0: + p0 = p1 + if "scl" in self.formulas: + p0 = self._scl(p0) + return p0 @public @@ -229,25 +243,26 @@ class LadderMultiplier(ScalarMultiplier): def multiply(self, scalar: int) -> Point: if not self._initialized: raise ValueError("ScalaMultiplier not initialized.") - if scalar == 0: - return copy(self._group.neutral) - q = self._point - if self.complete: - p0 = copy(self._group.neutral) - p1 = self._point - top = self._group.order.bit_length() - 1 - else: - p0 = copy(q) - p1 = self._dbl(q) - top = scalar.bit_length() - 2 - for i in range(top, -1, -1): - if scalar & (1 << i) == 0: - p0, p1 = self._ladd(q, p0, p1) + with ScalarMultiplicationAction(self._point, scalar): + if scalar == 0: + return copy(self._group.neutral) + q = self._point + if self.complete: + p0 = copy(self._group.neutral) + p1 = self._point + top = self._group.order.bit_length() - 1 else: - p1, p0 = self._ladd(q, p1, p0) - if "scl" in self.formulas: - p0 = self._scl(p0) - return p0 + p0 = copy(q) + p1 = self._dbl(q) + top = scalar.bit_length() - 2 + for i in range(top, -1, -1): + if scalar & (1 << i) == 0: + p0, p1 = self._ladd(q, p0, p1) + else: + p1, p0 = self._ladd(q, p1, p0) + if "scl" in self.formulas: + p0 = self._scl(p0) + return p0 @public @@ -267,24 +282,25 @@ class SimpleLadderMultiplier(ScalarMultiplier): def multiply(self, scalar: int) -> Point: if not self._initialized: raise ValueError("ScalaMultiplier not initialized.") - if scalar == 0: - return copy(self._group.neutral) - if self.complete: - top = self._group.order.bit_length() - 1 - else: - top = scalar.bit_length() - 1 - p0 = copy(self._group.neutral) - p1 = copy(self._point) - for i in range(top, -1, -1): - if scalar & (1 << i) == 0: - p1 = self._add(p0, p1) - p0 = self._dbl(p0) + with ScalarMultiplicationAction(self._point, scalar): + if scalar == 0: + return copy(self._group.neutral) + if self.complete: + top = self._group.order.bit_length() - 1 else: - p0 = self._add(p0, p1) - p1 = self._dbl(p1) - if "scl" in self.formulas: - p0 = self._scl(p0) - return p0 + top = scalar.bit_length() - 1 + p0 = copy(self._group.neutral) + p1 = copy(self._point) + for i in range(top, -1, -1): + if scalar & (1 << i) == 0: + p1 = self._add(p0, p1) + p0 = self._dbl(p0) + else: + p0 = self._add(p0, p1) + p1 = self._dbl(p1) + if "scl" in self.formulas: + p0 = self._scl(p0) + return p0 @public @@ -304,25 +320,26 @@ class DifferentialLadderMultiplier(ScalarMultiplier): def multiply(self, scalar: int) -> Point: if not self._initialized: raise ValueError("ScalaMultiplier not initialized.") - if scalar == 0: - return copy(self._group.neutral) - if self.complete: - top = self._group.order.bit_length() - 1 - else: - top = scalar.bit_length() - 1 - q = self._point - p0 = copy(self._group.neutral) - p1 = copy(q) - for i in range(top, -1, -1): - if scalar & (1 << i) == 0: - p1 = self._dadd(q, p0, p1) - p0 = self._dbl(p0) + with ScalarMultiplicationAction(self._point, scalar): + if scalar == 0: + return copy(self._group.neutral) + if self.complete: + top = self._group.order.bit_length() - 1 else: - p0 = self._dadd(q, p0, p1) - p1 = self._dbl(p1) - if "scl" in self.formulas: - p0 = self._scl(p0) - return p0 + top = scalar.bit_length() - 1 + q = self._point + p0 = copy(self._group.neutral) + p1 = copy(q) + for i in range(top, -1, -1): + if scalar & (1 << i) == 0: + p1 = self._dadd(q, p0, p1) + p0 = self._dbl(p0) + else: + p0 = self._dadd(q, p0, p1) + p1 = self._dbl(p1) + if "scl" in self.formulas: + p0 = self._scl(p0) + return p0 @public @@ -343,19 +360,20 @@ class BinaryNAFMultiplier(ScalarMultiplier): def multiply(self, scalar: int) -> Point: if not self._initialized: raise ValueError("ScalaMultiplier not initialized.") - if scalar == 0: - return copy(self._group.neutral) - bnaf = naf(scalar) - q = copy(self._group.neutral) - for val in bnaf: - q = self._dbl(q) - if val == 1: - q = self._add(q, self._point) - if val == -1: - q = self._add(q, self._point_neg) - if "scl" in self.formulas: - q = self._scl(q) - return q + with ScalarMultiplicationAction(self._point, scalar): + if scalar == 0: + return copy(self._group.neutral) + bnaf = naf(scalar) + q = copy(self._group.neutral) + for val in bnaf: + q = self._dbl(q) + if val == 1: + q = self._add(q, self._point) + if val == -1: + q = self._add(q, self._point_neg) + if "scl" in self.formulas: + q = self._scl(q) + return q @public @@ -390,20 +408,21 @@ class WindowNAFMultiplier(ScalarMultiplier): def multiply(self, scalar: int) -> Point: if not self._initialized: raise ValueError("ScalaMultiplier not initialized.") - if scalar == 0: - return copy(self._group.neutral) - naf = wnaf(scalar, self.width) - q = copy(self._group.neutral) - for val in naf: - q = self._dbl(q) - if val > 0: - q = self._add(q, self._points[val]) - elif val < 0: - if self.precompute_negation: - neg = self._points_neg[-val] - else: - neg = self._neg(self._points[-val]) - q = self._add(q, neg) - if "scl" in self.formulas: - q = self._scl(q) - return q + with ScalarMultiplicationAction(self._point, scalar): + if scalar == 0: + return copy(self._group.neutral) + naf = wnaf(scalar, self.width) + q = copy(self._group.neutral) + for val in naf: + q = self._dbl(q) + if val > 0: + q = self._add(q, self._points[val]) + elif val < 0: + if self.precompute_negation: + neg = self._points_neg[-val] + else: + neg = self._neg(self._points[-val]) + q = self._add(q, neg) + if "scl" in self.formulas: + q = self._scl(q) + return q |
