diff options
| -rw-r--r-- | pyecsca/ec/curve.py | 4 | ||||
| -rw-r--r-- | pyecsca/ec/mod.py | 14 | ||||
| -rw-r--r-- | pyecsca/ec/mult.py | 6 | ||||
| -rw-r--r-- | test/ec/test_mod.py | 30 |
4 files changed, 46 insertions, 8 deletions
diff --git a/pyecsca/ec/curve.py b/pyecsca/ec/curve.py index 79939b7..d4cef5d 100644 --- a/pyecsca/ec/curve.py +++ b/pyecsca/ec/curve.py @@ -12,11 +12,13 @@ class EllipticCurve(object): neutral: Point def __init__(self, model: Type[CurveModel], coordinate_model: CoordinateModel, - parameters: Mapping[str, int], neutral: Point = None): + parameters: Mapping[str, int], neutral: Point): if coordinate_model not in model.coordinates.values(): raise ValueError if set(model.parameter_names).symmetric_difference(parameters.keys()): raise ValueError + if neutral.coordinate_model != coordinate_model: + raise ValueError self.model = model self.coordinate_model = coordinate_model self.parameters = dict(parameters) diff --git a/pyecsca/ec/mod.py b/pyecsca/ec/mod.py index f75952d..bc1ebff 100644 --- a/pyecsca/ec/mod.py +++ b/pyecsca/ec/mod.py @@ -77,8 +77,6 @@ class Mod(object): @check def __mul__(self, other): - if self.n != other.n: - raise ValueError return Mod((self.x * other.x) % self.n, self.n) @check @@ -94,12 +92,20 @@ class Mod(object): return ~self * other @check + def __floordiv__(self, other): + return self * ~other + + @check + def __rfloordiv__(self, other): + return ~self * other + + @check def __div__(self, other): - return self.__truediv__(other) + return self.__floordiv__(other) @check def __rdiv__(self, other): - return self.__rtruediv__(other) + return self.__rfloordiv__(other) @check def __divmod__(self, divisor): diff --git a/pyecsca/ec/mult.py b/pyecsca/ec/mult.py index a32a54b..9255342 100644 --- a/pyecsca/ec/mult.py +++ b/pyecsca/ec/mult.py @@ -1,5 +1,5 @@ from copy import copy -from typing import Mapping, Tuple +from typing import Mapping, Tuple, Optional from .context import Context from .curve import EllipticCurve @@ -12,7 +12,7 @@ class ScalarMultiplier(object): formulas: Mapping[str, Formula] context: Context - def __init__(self, curve: EllipticCurve, ctx: Context = None, **formulas: Formula): + def __init__(self, curve: EllipticCurve, ctx: Context = None, **formulas: Optional[Formula]): for formula in formulas.values(): if formula is not None and formula.coordinate_model is not curve.coordinate_model: raise ValueError @@ -21,7 +21,7 @@ class ScalarMultiplier(object): self.context = ctx else: self.context = Context() - self.formulas = dict(formulas) + self.formulas = dict(filter(lambda pair: pair[1] is not None, formulas.items())) def _add(self, one: Point, other: Point) -> Point: if "add" not in self.formulas: diff --git a/test/ec/test_mod.py b/test/ec/test_mod.py new file mode 100644 index 0000000..6e4cfbf --- /dev/null +++ b/test/ec/test_mod.py @@ -0,0 +1,30 @@ +from unittest import TestCase + +from pyecsca.ec.mod import Mod, gcd, extgcd + + +class ModTests(TestCase): + + def test_gcd(self): + self.assertEqual(gcd(15, 20), 5) + self.assertEqual(extgcd(15, 20), (-1, 1, 5)) + + def test_wrong_mod(self): + a = Mod(5, 7) + b = Mod(4, 11) + with self.assertRaises(ValueError): + a + b + + def test_other(self): + a = Mod(5, 7) + b = Mod(3, 7) + self.assertEqual(int(-a), 2) + self.assertEqual(str(a), "5") + self.assertEqual(6 - a, Mod(1, 7)) + self.assertNotEqual(a, b) + self.assertEqual(a / b, Mod(4, 7)) + self.assertEqual(a // b, Mod(4, 7)) + self.assertEqual(5 / b, Mod(4, 7)) + self.assertEqual(5 // b, Mod(4, 7)) + self.assertEqual(divmod(a, b), (Mod(1, 7), Mod(2, 7))) + self.assertNotEqual(a, 6)
\ No newline at end of file |
