aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--pyecsca/ec/mod/base.py58
-rw-r--r--pyecsca/ec/mod/flint.py2
-rw-r--r--pyecsca/ec/mod/gmp.py2
-rw-r--r--pyecsca/ec/mod/raw.py2
4 files changed, 35 insertions, 29 deletions
diff --git a/pyecsca/ec/mod/base.py b/pyecsca/ec/mod/base.py
index 6555ba1..12020d6 100644
--- a/pyecsca/ec/mod/base.py
+++ b/pyecsca/ec/mod/base.py
@@ -3,7 +3,7 @@ import secrets
from functools import lru_cache, wraps
from public import public
-from typing import Tuple, Any, Dict, Type, Set, TypeVar
+from typing import Tuple, Any, Dict, Type, Set, TypeVar, Generic
from pyecsca.ec.context import ResultAction
from pyecsca.misc.cfg import getconfig
@@ -238,7 +238,7 @@ _mod_order = ["gmp", "flint", "python"]
@public
-class Mod:
+class Mod(Generic[M]):
"""
An element x of ℤₙ.
@@ -280,25 +280,25 @@ class Mod:
raise TypeError("Abstract")
@_check
- def __add__(self, other) -> "Mod":
+ def __add__(self: M, other) -> M:
return self.__class__((self.x + other.x) % self.n, self.n)
@_check
- def __radd__(self, other) -> "Mod":
+ def __radd__(self: M, other) -> M:
return self + other
@_check
- def __sub__(self, other) -> "Mod":
+ def __sub__(self: M, other) -> M:
return self.__class__((self.x - other.x) % self.n, self.n)
@_check
- def __rsub__(self, other) -> "Mod":
+ def __rsub__(self: M, other) -> M:
return -self + other
- def __neg__(self) -> "Mod":
+ def __neg__(self: M) -> M:
return self.__class__(self.n - self.x, self.n)
- def bit_length(self):
+ def bit_length(self: M) -> int:
"""
Compute the bit length of this element (in its positive integer representation).
@@ -306,7 +306,7 @@ class Mod:
"""
raise NotImplementedError
- def inverse(self) -> "Mod":
+ def inverse(self: M) -> M:
"""
Invert the element.
@@ -315,14 +315,14 @@ class Mod:
"""
raise NotImplementedError
- def __invert__(self) -> "Mod":
+ def __invert__(self: M) -> M:
return self.inverse()
- def is_residue(self) -> bool:
+ def is_residue(self: M) -> bool:
"""Whether this element is a quadratic residue (only implemented for prime modulus)."""
raise NotImplementedError
- def sqrt(self) -> "Mod":
+ def sqrt(self: M) -> M:
"""
Compute the modular square root of this element (only implemented for prime modulus).
@@ -330,13 +330,13 @@ class Mod:
"""
raise NotImplementedError
- def is_cubic_residue(self) -> bool:
+ def is_cubic_residue(self: M) -> bool:
"""
Whether this element is a cubic residue (only implemented for prime modulus).
"""
raise NotImplementedError
- def cube_root(self) -> "Mod":
+ def cube_root(self: M) -> M:
"""
Compute the cube root of this element (only implemented for prime modulus).
@@ -345,33 +345,33 @@ class Mod:
raise NotImplementedError
@_check
- def __mul__(self, other) -> "Mod":
+ def __mul__(self: M, other) -> M:
return self.__class__((self.x * other.x) % self.n, self.n)
@_check
- def __rmul__(self, other) -> "Mod":
+ def __rmul__(self: M, other) -> M:
return self * other
@_check
- def __truediv__(self, other) -> "Mod":
+ def __truediv__(self: M, other) -> M:
return self * ~other
@_check
- def __rtruediv__(self, other) -> "Mod":
+ def __rtruediv__(self: M, other) -> M:
return ~self * other
@_check
- def __floordiv__(self, other) -> "Mod":
+ def __floordiv__(self: M, other) -> M:
return self * ~other
@_check
- def __rfloordiv__(self, other) -> "Mod":
+ def __rfloordiv__(self: M, other) -> M:
return ~self * other
- def __bytes__(self) -> bytes:
+ def __bytes__(self: M) -> bytes:
raise NotImplementedError
- def __int__(self) -> int:
+ def __int__(self: M) -> int:
raise NotImplementedError
@classmethod
@@ -385,18 +385,18 @@ class Mod:
with RandomModAction(n) as action:
return action.exit(mod(secrets.randbelow(n), n))
- def __pow__(self, n, _=None) -> "Mod":
+ def __pow__(self: M, n, _=None) -> M:
return NotImplemented
- def __str__(self):
+ def __str__(self: M):
return str(self.x)
- def __format__(self, format_spec):
+ def __format__(self: M, format_spec):
return format(int(self), format_spec)
@public
-class Undefined(Mod):
+class Undefined(Mod["Undefined"]):
"""A special undefined element."""
__slots__ = ("x", "n")
@@ -432,6 +432,12 @@ class Undefined(Mod):
def is_residue(self):
raise NotImplementedError
+ def cube_root(self):
+ raise NotImplementedError
+
+ def is_cubic_residue(self):
+ raise NotImplementedError
+
def __invert__(self):
raise NotImplementedError
diff --git a/pyecsca/ec/mod/flint.py b/pyecsca/ec/mod/flint.py
index 69e29f4..890d727 100644
--- a/pyecsca/ec/mod/flint.py
+++ b/pyecsca/ec/mod/flint.py
@@ -50,7 +50,7 @@ if has_flint:
return method
@public
- class FlintMod(Mod):
+ class FlintMod(Mod["FlintMod"]):
"""An element x of ℤₙ. Implemented by flint."""
x: flint.fmpz_mod
diff --git a/pyecsca/ec/mod/gmp.py b/pyecsca/ec/mod/gmp.py
index 83f0aaf..268a752 100644
--- a/pyecsca/ec/mod/gmp.py
+++ b/pyecsca/ec/mod/gmp.py
@@ -37,7 +37,7 @@ if has_gmp:
return gmpy2.is_prime(x)
@public
- class GMPMod(Mod):
+ class GMPMod(Mod["GMPMod"]):
"""An element x of ℤₙ. Implemented by GMP."""
x: gmpy2.mpz
diff --git a/pyecsca/ec/mod/raw.py b/pyecsca/ec/mod/raw.py
index a70c627..dd5e4f8 100644
--- a/pyecsca/ec/mod/raw.py
+++ b/pyecsca/ec/mod/raw.py
@@ -8,7 +8,7 @@ from pyecsca.ec.mod.base import Mod, extgcd, miller_rabin, jacobi, cube_root_inn
@public
-class RawMod(Mod):
+class RawMod(Mod["RawMod"]):
"""An element x of ℤₙ (implemented using Python integers)."""
x: int