diff options
Diffstat (limited to 'pyecsca/ec/mult.py')
| -rw-r--r-- | pyecsca/ec/mult.py | 163 |
1 files changed, 82 insertions, 81 deletions
diff --git a/pyecsca/ec/mult.py b/pyecsca/ec/mult.py index ab6cf7b..c766ff0 100644 --- a/pyecsca/ec/mult.py +++ b/pyecsca/ec/mult.py @@ -12,73 +12,79 @@ from .point import Point class ScalarMultiplier(object): - group: AbelianGroup + """ + A scalar multiplication algorithm. + + :param short_circuit: Whether the use of formulas will be guarded by short-circuit on inputs + of the point at infinity. + :param formulas: Formulas this instance will use. + """ + short_circuit: bool formulas: Mapping[str, Formula] + _group: AbelianGroup _point: Point = None - def __init__(self, group: AbelianGroup, **formulas: Optional[Formula]): - for formula in formulas.values(): - if formula is not None and formula.coordinate_model is not group.curve.coordinate_model: - raise ValueError - self.group = group + def __init__(self, short_circuit=True, **formulas: Optional[Formula]): + if len(set(formula.coordinate_model for formula in formulas.values() if + formula is not None)) != 1: + raise ValueError + self.short_circuit = short_circuit self.formulas = dict(filter(lambda pair: pair[1] is not None, formulas.items())) def _add(self, one: Point, other: Point) -> Point: if "add" not in self.formulas: raise NotImplementedError - if one == self.group.neutral: - return copy(other) - if other == self.group.neutral: - return copy(one) + if self.short_circuit: + if one == self._group.neutral: + return copy(other) + if other == self._group.neutral: + return copy(one) return \ - getcontext().execute(self.formulas["add"], one, other, **self.group.curve.parameters)[0] + getcontext().execute(self.formulas["add"], one, other, **self._group.curve.parameters)[0] def _dbl(self, point: Point) -> Point: if "dbl" not in self.formulas: raise NotImplementedError - if point == self.group.neutral: - return copy(point) - return getcontext().execute(self.formulas["dbl"], point, **self.group.curve.parameters)[0] + if self.short_circuit: + if point == self._group.neutral: + return copy(point) + return getcontext().execute(self.formulas["dbl"], point, **self._group.curve.parameters)[0] def _scl(self, point: Point) -> Point: if "scl" not in self.formulas: raise NotImplementedError - return getcontext().execute(self.formulas["scl"], point, **self.group.curve.parameters)[0] + return getcontext().execute(self.formulas["scl"], point, **self._group.curve.parameters)[0] def _ladd(self, start: Point, to_dbl: Point, to_add: Point) -> Tuple[Point, ...]: if "ladd" not in self.formulas: raise NotImplementedError return getcontext().execute(self.formulas["ladd"], start, to_dbl, to_add, - **self.group.curve.parameters) + **self._group.curve.parameters) def _dadd(self, start: Point, one: Point, other: Point) -> Point: if "dadd" not in self.formulas: raise NotImplementedError - if one == self.group.neutral: - return copy(other) - if other == self.group.neutral: - return copy(one) + if self.short_circuit: + if one == self._group.neutral: + return copy(other) + if other == self._group.neutral: + return copy(one) return getcontext().execute(self.formulas["dadd"], start, one, other, - **self.group.curve.parameters)[0] + **self._group.curve.parameters)[0] def _neg(self, point: Point) -> Point: if "neg" not in self.formulas: raise NotImplementedError - return getcontext().execute(self.formulas["neg"], point, **self.group.curve.parameters)[0] + return getcontext().execute(self.formulas["neg"], point, **self._group.curve.parameters)[0] - def init(self, point: Point): + def init(self, group: AbelianGroup, point: Point): + coord_model = set(self.formulas.values()).pop().coordinate_model + if group.curve.coordinate_model != coord_model or point.coordinate_model != coord_model: + raise ValueError + self._group = group self._point = point - def _init_multiply(self, point: Optional[Point]) -> Point: - if point is None: - if self._point is None: - raise ValueError - else: - if self._point != point: - self.init(point) - return self._point - - def multiply(self, scalar: int, point: Optional[Point] = None) -> Point: + def multiply(self, scalar: int) -> Point: raise NotImplementedError @@ -91,16 +97,16 @@ class LTRMultiplier(ScalarMultiplier): """ always: bool - def __init__(self, group: AbelianGroup, add: AdditionFormula, dbl: DoublingFormula, + def __init__(self, add: AdditionFormula, dbl: DoublingFormula, scl: ScalingFormula = None, always: bool = False): - super().__init__(group, add=add, dbl=dbl, scl=scl) + super().__init__(add=add, dbl=dbl, scl=scl) self.always = always - def multiply(self, scalar: int, point: Optional[Point] = None) -> Point: + def multiply(self, scalar: int) -> Point: if scalar == 0: - return copy(self.group.neutral) - q = self._init_multiply(point) - r = copy(self.group.neutral) + return copy(self._group.neutral) + q = self._point + r = copy(self._group.neutral) for i in range(scalar.bit_length() - 1, -1, -1): r = self._dbl(r) if scalar & (1 << i) != 0: @@ -121,16 +127,16 @@ class RTLMultiplier(ScalarMultiplier): """ always: bool - def __init__(self, group: AbelianGroup, add: AdditionFormula, dbl: DoublingFormula, + def __init__(self, add: AdditionFormula, dbl: DoublingFormula, scl: ScalingFormula = None, always: bool = False): - super().__init__(group, add=add, dbl=dbl, scl=scl) + super().__init__(add=add, dbl=dbl, scl=scl) self.always = always - def multiply(self, scalar: int, point: Optional[Point] = None) -> Point: + def multiply(self, scalar: int) -> Point: if scalar == 0: - return copy(self.group.neutral) - q = self._init_multiply(point) - r = copy(self.group.neutral) + return copy(self._group.neutral) + q = self._point + r = copy(self._group.neutral) while scalar > 0: if scalar & 1 != 0: r = self._add(r, q) @@ -152,14 +158,13 @@ class CoronMultiplier(ScalarMultiplier): https://link.springer.com/content/pdf/10.1007/3-540-48059-5_25.pdf """ - def __init__(self, group: AbelianGroup, add: AdditionFormula, dbl: DoublingFormula, - scl: ScalingFormula = None): - super().__init__(group, add=add, dbl=dbl, scl=scl) + def __init__(self, add: AdditionFormula, dbl: DoublingFormula, scl: ScalingFormula = None): + super().__init__(add=add, dbl=dbl, scl=scl) - def multiply(self, scalar: int, point: Optional[Point] = None): + def multiply(self, scalar: int) -> Point: if scalar == 0: - return copy(self.group.neutral) - q = self._init_multiply(point) + return copy(self._group.neutral) + q = self._point p0 = copy(q) for i in range(scalar.bit_length() - 2, -1, -1): p0 = self._dbl(p0) @@ -177,14 +182,13 @@ class LadderMultiplier(ScalarMultiplier): Montgomery ladder multiplier, using a three input, two output ladder formula. """ - def __init__(self, group: AbelianGroup, ladd: LadderFormula, dbl: DoublingFormula, - scl: ScalingFormula = None): - super().__init__(group, ladd=ladd, dbl=dbl, scl=scl) + def __init__(self, ladd: LadderFormula, dbl: DoublingFormula, scl: ScalingFormula = None): + super().__init__(ladd=ladd, dbl=dbl, scl=scl) - def multiply(self, scalar: int, point: Optional[Point] = None) -> Point: + def multiply(self, scalar: int) -> Point: if scalar == 0: - return copy(self.group.neutral) - q = self._init_multiply(point) + return copy(self._group.neutral) + q = self._point p0 = copy(q) p1 = self._dbl(q) for i in range(scalar.bit_length() - 2, -1, -1): @@ -204,22 +208,21 @@ class SimpleLadderMultiplier(ScalarMultiplier): """ _differential: bool = False - def __init__(self, group: AbelianGroup, - add: Union[AdditionFormula, DifferentialAdditionFormula], dbl: DoublingFormula, + def __init__(self, add: Union[AdditionFormula, DifferentialAdditionFormula], dbl: DoublingFormula, scl: ScalingFormula = None): if isinstance(add, AdditionFormula): - super().__init__(group, add=add, dbl=dbl, scl=scl) + super().__init__(add=add, dbl=dbl, scl=scl) elif isinstance(add, DifferentialAdditionFormula): - super().__init__(group, dadd=add, dbl=dbl, scl=scl) + super().__init__(dadd=add, dbl=dbl, scl=scl) self._differential = True else: raise ValueError - def multiply(self, scalar: int, point: Optional[Point] = None) -> Point: + def multiply(self, scalar: int) -> Point: if scalar == 0: - return copy(self.group.neutral) - q = self._init_multiply(point) - p0 = copy(self.group.neutral) + return copy(self._group.neutral) + q = self._point + p0 = copy(self._group.neutral) p1 = copy(q) for i in range(scalar.bit_length() - 1, -1, -1): if scalar & (1 << i) == 0: @@ -246,20 +249,19 @@ class BinaryNAFMultiplier(ScalarMultiplier): """ _point_neg: Point - def __init__(self, group: AbelianGroup, add: AdditionFormula, dbl: DoublingFormula, + def __init__(self, add: AdditionFormula, dbl: DoublingFormula, neg: NegationFormula, scl: ScalingFormula = None): - super().__init__(group, add=add, dbl=dbl, neg=neg, scl=scl) + super().__init__(add=add, dbl=dbl, neg=neg, scl=scl) - def init(self, point: Point): - super().init(point) + def init(self, group: AbelianGroup, point: Point): + super().init(group, point) self._point_neg = self._neg(point) - def multiply(self, scalar: int, point: Optional[Point] = None) -> Point: + def multiply(self, scalar: int) -> Point: if scalar == 0: - return copy(self.group.neutral) - self._init_multiply(point) + return copy(self._group.neutral) bnaf = naf(scalar) - q = copy(self.group.neutral) + q = copy(self._group.neutral) for val in bnaf: q = self._dbl(q) if val == 1: @@ -281,15 +283,15 @@ class WindowNAFMultiplier(ScalarMultiplier): _precompute_neg: bool = False _width: int - def __init__(self, group: AbelianGroup, add: AdditionFormula, dbl: DoublingFormula, + def __init__(self, add: AdditionFormula, dbl: DoublingFormula, neg: NegationFormula, width: int, scl: ScalingFormula = None, precompute_negation: bool = False): - super().__init__(group, add=add, dbl=dbl, neg=neg, scl=scl) + super().__init__(add=add, dbl=dbl, neg=neg, scl=scl) self._width = width self._precompute_neg = precompute_negation - def init(self, point: Point): - self._point = point + def init(self, group: AbelianGroup, point: Point): + super().init(group, point) self._points = {} self._points_neg = {} current_point = point @@ -300,12 +302,11 @@ class WindowNAFMultiplier(ScalarMultiplier): self._points_neg[2 ** i - 1] = self._neg(current_point) current_point = self._add(current_point, double_point) - def multiply(self, scalar: int, point: Optional[Point] = None): + def multiply(self, scalar: int) -> Point: if scalar == 0: - return copy(self.group.neutral) - self._init_multiply(point) + return copy(self._group.neutral) naf = wnaf(scalar, self._width) - q = copy(self.group.neutral) + q = copy(self._group.neutral) for val in naf: q = self._dbl(q) if val > 0: |
