aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorJ08nY2018-12-14 19:10:19 +0100
committerJ08nY2019-03-21 11:00:14 +0100
commitba8212dbc9cee4c098838534c096486ad5bf759a (patch)
tree7d17d961103bf4bfad6de82e8b76bb663f82830f
parent251811d90066e561b99b6580838abc20eaaa2009 (diff)
downloadpyecsca-ba8212dbc9cee4c098838534c096486ad5bf759a.tar.gz
pyecsca-ba8212dbc9cee4c098838534c096486ad5bf759a.tar.zst
pyecsca-ba8212dbc9cee4c098838534c096486ad5bf759a.zip
Fix basic scalar multipliers.
-rw-r--r--pyecsca/ec/context.py9
-rw-r--r--pyecsca/ec/mod.py8
-rw-r--r--pyecsca/ec/mult.py54
-rw-r--r--test/ec/test_mult.py44
4 files changed, 79 insertions, 36 deletions
diff --git a/pyecsca/ec/context.py b/pyecsca/ec/context.py
index a65f461..b4667cf 100644
--- a/pyecsca/ec/context.py
+++ b/pyecsca/ec/context.py
@@ -7,17 +7,14 @@ from .point import Point
class Context(object):
intermediates: List[Tuple[str, Mod]]
+ actions: List[Tuple[Formula, Tuple[Point, ...]]]
def __init__(self):
self.intermediates = []
-
- def __enter__(self):
- return self
-
- def __exit__(self, exc_type, exc_val, exc_tb):
- pass
+ self.actions = []
def execute(self, formula: Formula, *points: Point, **params: Mod) -> Point:
+ self.actions.append((formula, tuple(points)))
coords = {}
for i, point in enumerate(points):
if point.coordinate_model != formula.coordinate_model:
diff --git a/pyecsca/ec/mod.py b/pyecsca/ec/mod.py
index 1deb6de..f75952d 100644
--- a/pyecsca/ec/mod.py
+++ b/pyecsca/ec/mod.py
@@ -109,6 +109,14 @@ class Mod(object):
def __int__(self):
return self.x
+ def __eq__(self, other):
+ if type(other) is not Mod:
+ return False
+ return self.x == other.x and self.n == other.n
+
+ def __ne__(self, other):
+ return not self == other
+
def __repr__(self):
return str(self.x)
diff --git a/pyecsca/ec/mult.py b/pyecsca/ec/mult.py
index 6bf40ec..4dd0f5b 100644
--- a/pyecsca/ec/mult.py
+++ b/pyecsca/ec/mult.py
@@ -23,6 +23,27 @@ class ScalarMultiplier(object):
self.context = Context()
self.formulas = dict(formulas)
+ def _add(self, one: Point, other: Point) -> Point:
+ if "add" not in self.formulas:
+ raise NotImplementedError
+ if one == self.curve.neutral:
+ return copy(other)
+ if other == self.curve.neutral:
+ return copy(one)
+ return self.context.execute(self.formulas["add"], one, other, **self.curve.parameters)
+
+ def _dbl(self, point: Point) -> Point:
+ if "dbl" not in self.formulas:
+ raise NotImplementedError
+ if point == self.curve.neutral:
+ return copy(point)
+ return self.context.execute(self.formulas["dbl"], point, **self.curve.parameters)
+
+ def _scl(self, point: Point) -> Point:
+ if "scl" not in self.formulas:
+ raise NotImplementedError
+ return self.context.execute(self.formulas["scl"], point, **self.curve.parameters)
+
def multiply(self, scalar: int, point: Point) -> Point:
raise NotImplementedError
@@ -37,35 +58,36 @@ class LTRMultiplier(ScalarMultiplier):
self.always = always
def multiply(self, scalar: int, point: Point) -> Point:
- pass
+ r = copy(self.curve.neutral)
+ for i in range(scalar.bit_length(), -1, -1):
+ r = self._dbl(r)
+ if scalar & (1 << i) != 0:
+ r = self._add(r, point)
+ elif self.always:
+ self._add(r, point)
+ if "scl" in self.formulas:
+ r = self._scl(r)
+ return r
class RTLMultiplier(ScalarMultiplier):
always: bool
- scale: bool
def __init__(self, curve: EllipticCurve, add: AdditionFormula, dbl: DoublingFormula,
scl: ScalingFormula = None,
- ctx: Context = None, scale: bool = True, always: bool = False):
+ ctx: Context = None, always: bool = False):
super().__init__(curve, ctx, add=add, dbl=dbl, scl=scl)
self.always = always
- self.scale = scale
def multiply(self, scalar: int, point: Point) -> Point:
- q = copy(point)
r = copy(self.curve.neutral)
while scalar > 0:
- q = self.context.execute(self.formulas["dbl"], q, **self.curve.parameters)
- if self.always:
- tmp = self.context.execute(self.formulas["add"], r, q, **self.curve.parameters)
- else:
- if r == self.curve.neutral:
- tmp = copy(q)
- else:
- tmp = self.context.execute(self.formulas["add"], r, q, **self.curve.parameters)
if scalar & 1 != 0:
- r = tmp
+ r = self._add(r, point)
+ elif self.always:
+ self._add(r, point)
+ point = self._dbl(point)
scalar >>= 1
- if self.scale:
- r = self.context.execute(self.formulas["scl"], r, **self.curve.parameters)
+ if "scl" in self.formulas:
+ r = self._scl(r)
return r
diff --git a/test/ec/test_mult.py b/test/ec/test_mult.py
index 6fadec7..ec26335 100644
--- a/test/ec/test_mult.py
+++ b/test/ec/test_mult.py
@@ -1,25 +1,41 @@
from unittest import TestCase
-from pyecsca.ec.context import Context
from pyecsca.ec.curve import EllipticCurve
from pyecsca.ec.mod import Mod
from pyecsca.ec.model import ShortWeierstrassModel
-from pyecsca.ec.mult import RTLMultiplier
+from pyecsca.ec.mult import LTRMultiplier, RTLMultiplier
from pyecsca.ec.point import Point
class ScalarMultiplierTests(TestCase):
+ def setUp(self):
+ self.p = 0xfffffffdffffffffffffffffffffffff
+ self.coords = ShortWeierstrassModel.coordinates["projective"]
+ self.secp128r1 = EllipticCurve(ShortWeierstrassModel, self.coords,
+ dict(a=0xfffffffdfffffffffffffffffffffffc,
+ b=0xe87579c11079f43dd824993c2cee5ed3),
+ Point(self.coords, X=Mod(0, self.p), Y=Mod(1, self.p),
+ Z=Mod(0, self.p)))
+ self.base = Point(self.coords, X=Mod(0x161ff7528b899b2d0c28607ca52c5b86, self.p),
+ Y=Mod(0xcf5ac8395bafeb13c02da292dded7a83, self.p),
+ Z=Mod(1, self.p))
+
def test_rtl_simple(self):
- p = 0xfffffffdffffffffffffffffffffffff
- coords = ShortWeierstrassModel.coordinates["projective"]
- curve = EllipticCurve(ShortWeierstrassModel, coords,
- dict(a=0xfffffffdfffffffffffffffffffffffc,
- b=0xe87579c11079f43dd824993c2cee5ed3),
- Point(coords, X=Mod(0, p), Y=Mod(1, p), Z=Mod(0, p)))
- with Context() as ctx:
- mult = RTLMultiplier(curve, coords.formulas["add-1998-cmo"],
- coords.formulas["dbl-1998-cmo"], coords.formulas["z"], ctx=ctx)
- mult.multiply(10, Point(coords, X=Mod(0x161ff7528b899b2d0c28607ca52c5b86, p),
- Y=Mod(0xcf5ac8395bafeb13c02da292dded7a83, p),
- Z=Mod(1, p)))
+ mult = RTLMultiplier(self.secp128r1, self.coords.formulas["add-1998-cmo"],
+ self.coords.formulas["dbl-1998-cmo"], self.coords.formulas["z"])
+ mult.multiply(10, self.base)
+
+ def test_ltr_simple(self):
+ mult = LTRMultiplier(self.secp128r1, self.coords.formulas["add-1998-cmo"],
+ self.coords.formulas["dbl-1998-cmo"], self.coords.formulas["z"])
+ mult.multiply(10, self.base)
+
+ def test_basic_multipliers(self):
+ ltr = LTRMultiplier(self.secp128r1, self.coords.formulas["add-1998-cmo"],
+ self.coords.formulas["dbl-1998-cmo"], self.coords.formulas["z"])
+ res_ltr = ltr.multiply(10, self.base)
+ rtl = RTLMultiplier(self.secp128r1, self.coords.formulas["add-1998-cmo"],
+ self.coords.formulas["dbl-1998-cmo"], self.coords.formulas["z"])
+ res_rtl = rtl.multiply(10, self.base)
+ self.assertEqual(res_ltr, res_rtl)