diff options
Diffstat (limited to 'pyecsca/ec/mult/window.py')
| -rw-r--r-- | pyecsca/ec/mult/window.py | 20 |
1 files changed, 13 insertions, 7 deletions
diff --git a/pyecsca/ec/mult/window.py b/pyecsca/ec/mult/window.py index c200cc5..6fbee24 100644 --- a/pyecsca/ec/mult/window.py +++ b/pyecsca/ec/mult/window.py @@ -10,7 +10,7 @@ from pyecsca.ec.mult.base import ( ScalarMultiplicationAction, PrecomputationAction, ProcessingDirection, - AccumulatorMultiplier, + AccumulatorMultiplier, PrecompMultiplier, ) from pyecsca.ec.formula import ( AdditionFormula, @@ -28,7 +28,7 @@ from pyecsca.ec.scalar import ( @public -class SlidingWindowMultiplier(AccumulatorMultiplier, ScalarMultiplier): +class SlidingWindowMultiplier(AccumulatorMultiplier, PrecompMultiplier, ScalarMultiplier): """ Sliding window scalar multiplier. @@ -91,7 +91,7 @@ class SlidingWindowMultiplier(AccumulatorMultiplier, ScalarMultiplier): return f"{self.__class__.__name__}({', '.join(map(str, self.formulas.values()))}, short_circuit={self.short_circuit}, width={self.width}, recoding_direction={self.recoding_direction.name}, accumulation_order={self.accumulation_order.name})" def init(self, params: DomainParameters, point: Point): - with PrecomputationAction(params, point): + with PrecomputationAction(params, point) as action: super().init(params, point) self._points = {} current_point = point @@ -99,6 +99,7 @@ class SlidingWindowMultiplier(AccumulatorMultiplier, ScalarMultiplier): for i in range(0, 2 ** (self.width - 1)): self._points[2 * i + 1] = current_point current_point = self._add(current_point, double_point) + action.exit(self._points) def multiply(self, scalar: int) -> Point: if not self._initialized: @@ -121,7 +122,7 @@ class SlidingWindowMultiplier(AccumulatorMultiplier, ScalarMultiplier): @public -class FixedWindowLTRMultiplier(AccumulatorMultiplier, ScalarMultiplier): +class FixedWindowLTRMultiplier(AccumulatorMultiplier, PrecompMultiplier, ScalarMultiplier): """ Like LTRMultiplier, but m-ary, not binary. @@ -186,7 +187,7 @@ class FixedWindowLTRMultiplier(AccumulatorMultiplier, ScalarMultiplier): return f"{self.__class__.__name__}({', '.join(map(str, self.formulas.values()))}, short_circuit={self.short_circuit}, m={self.m}, accumulation_order={self.accumulation_order.name})" def init(self, params: DomainParameters, point: Point): - with PrecomputationAction(params, point): + with PrecomputationAction(params, point) as action: super().init(params, point) double_point = self._dbl(point) self._points = {1: point, 2: double_point} @@ -194,6 +195,7 @@ class FixedWindowLTRMultiplier(AccumulatorMultiplier, ScalarMultiplier): for i in range(3, self.m): current_point = self._add(current_point, point) self._points[i] = current_point + action.exit(self._points) def _mult_m(self, point: Point) -> Point: if self.m & (self.m - 1) == 0: @@ -229,7 +231,7 @@ class FixedWindowLTRMultiplier(AccumulatorMultiplier, ScalarMultiplier): @public -class WindowBoothMultiplier(AccumulatorMultiplier, ScalarMultiplier): +class WindowBoothMultiplier(AccumulatorMultiplier, PrecompMultiplier, ScalarMultiplier): """ :param short_circuit: Whether the use of formulas will be guarded by short-circuit on inputs @@ -297,7 +299,7 @@ class WindowBoothMultiplier(AccumulatorMultiplier, ScalarMultiplier): return f"{self.__class__.__name__}({', '.join(map(str, self.formulas.values()))}, short_circuit={self.short_circuit}, width={self.width}, precompute_negation={self.precompute_negation}, accumulation_order={self.accumulation_order.name})" def init(self, params: DomainParameters, point: Point): - with PrecomputationAction(params, point): + with PrecomputationAction(params, point) as actions: super().init(params, point) double_point = self._dbl(point) self._points = {1: point, 2: double_point} @@ -309,6 +311,10 @@ class WindowBoothMultiplier(AccumulatorMultiplier, ScalarMultiplier): self._points[i] = current_point if self.precompute_negation: self._points_neg[i] = self._neg(current_point) + if self.precompute_negation: + actions.exit({**self._points, **self._points_neg}) + else: + actions.exit(self._points) def multiply(self, scalar: int) -> Point: if not self._initialized: |
