aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--pyecsca/ec/curve.py4
-rw-r--r--pyecsca/ec/mod.py14
-rw-r--r--pyecsca/ec/mult.py6
-rw-r--r--test/ec/test_mod.py30
4 files changed, 46 insertions, 8 deletions
diff --git a/pyecsca/ec/curve.py b/pyecsca/ec/curve.py
index 79939b7..d4cef5d 100644
--- a/pyecsca/ec/curve.py
+++ b/pyecsca/ec/curve.py
@@ -12,11 +12,13 @@ class EllipticCurve(object):
neutral: Point
def __init__(self, model: Type[CurveModel], coordinate_model: CoordinateModel,
- parameters: Mapping[str, int], neutral: Point = None):
+ parameters: Mapping[str, int], neutral: Point):
if coordinate_model not in model.coordinates.values():
raise ValueError
if set(model.parameter_names).symmetric_difference(parameters.keys()):
raise ValueError
+ if neutral.coordinate_model != coordinate_model:
+ raise ValueError
self.model = model
self.coordinate_model = coordinate_model
self.parameters = dict(parameters)
diff --git a/pyecsca/ec/mod.py b/pyecsca/ec/mod.py
index f75952d..bc1ebff 100644
--- a/pyecsca/ec/mod.py
+++ b/pyecsca/ec/mod.py
@@ -77,8 +77,6 @@ class Mod(object):
@check
def __mul__(self, other):
- if self.n != other.n:
- raise ValueError
return Mod((self.x * other.x) % self.n, self.n)
@check
@@ -94,12 +92,20 @@ class Mod(object):
return ~self * other
@check
+ def __floordiv__(self, other):
+ return self * ~other
+
+ @check
+ def __rfloordiv__(self, other):
+ return ~self * other
+
+ @check
def __div__(self, other):
- return self.__truediv__(other)
+ return self.__floordiv__(other)
@check
def __rdiv__(self, other):
- return self.__rtruediv__(other)
+ return self.__rfloordiv__(other)
@check
def __divmod__(self, divisor):
diff --git a/pyecsca/ec/mult.py b/pyecsca/ec/mult.py
index a32a54b..9255342 100644
--- a/pyecsca/ec/mult.py
+++ b/pyecsca/ec/mult.py
@@ -1,5 +1,5 @@
from copy import copy
-from typing import Mapping, Tuple
+from typing import Mapping, Tuple, Optional
from .context import Context
from .curve import EllipticCurve
@@ -12,7 +12,7 @@ class ScalarMultiplier(object):
formulas: Mapping[str, Formula]
context: Context
- def __init__(self, curve: EllipticCurve, ctx: Context = None, **formulas: Formula):
+ def __init__(self, curve: EllipticCurve, ctx: Context = None, **formulas: Optional[Formula]):
for formula in formulas.values():
if formula is not None and formula.coordinate_model is not curve.coordinate_model:
raise ValueError
@@ -21,7 +21,7 @@ class ScalarMultiplier(object):
self.context = ctx
else:
self.context = Context()
- self.formulas = dict(formulas)
+ self.formulas = dict(filter(lambda pair: pair[1] is not None, formulas.items()))
def _add(self, one: Point, other: Point) -> Point:
if "add" not in self.formulas:
diff --git a/test/ec/test_mod.py b/test/ec/test_mod.py
new file mode 100644
index 0000000..6e4cfbf
--- /dev/null
+++ b/test/ec/test_mod.py
@@ -0,0 +1,30 @@
+from unittest import TestCase
+
+from pyecsca.ec.mod import Mod, gcd, extgcd
+
+
+class ModTests(TestCase):
+
+ def test_gcd(self):
+ self.assertEqual(gcd(15, 20), 5)
+ self.assertEqual(extgcd(15, 20), (-1, 1, 5))
+
+ def test_wrong_mod(self):
+ a = Mod(5, 7)
+ b = Mod(4, 11)
+ with self.assertRaises(ValueError):
+ a + b
+
+ def test_other(self):
+ a = Mod(5, 7)
+ b = Mod(3, 7)
+ self.assertEqual(int(-a), 2)
+ self.assertEqual(str(a), "5")
+ self.assertEqual(6 - a, Mod(1, 7))
+ self.assertNotEqual(a, b)
+ self.assertEqual(a / b, Mod(4, 7))
+ self.assertEqual(a // b, Mod(4, 7))
+ self.assertEqual(5 / b, Mod(4, 7))
+ self.assertEqual(5 // b, Mod(4, 7))
+ self.assertEqual(divmod(a, b), (Mod(1, 7), Mod(2, 7)))
+ self.assertNotEqual(a, 6) \ No newline at end of file