aboutsummaryrefslogtreecommitdiff
path: root/pyecsca/ec/mod/base.py
diff options
context:
space:
mode:
authorJ08nY2025-03-31 09:31:25 +0200
committerJ08nY2025-03-31 09:31:25 +0200
commitebddd1ad1f031f653dff49e38b4d298f8b52bebd (patch)
tree64e983a07d9940bc6dffdd5434bc35be9ed5b60e /pyecsca/ec/mod/base.py
parentad342838982b0c061baa2525afa622189025b887 (diff)
downloadpyecsca-ebddd1ad1f031f653dff49e38b4d298f8b52bebd.tar.gz
pyecsca-ebddd1ad1f031f653dff49e38b4d298f8b52bebd.tar.zst
pyecsca-ebddd1ad1f031f653dff49e38b4d298f8b52bebd.zip
Diffstat (limited to 'pyecsca/ec/mod/base.py')
-rw-r--r--pyecsca/ec/mod/base.py58
1 files changed, 32 insertions, 26 deletions
diff --git a/pyecsca/ec/mod/base.py b/pyecsca/ec/mod/base.py
index 6555ba1..12020d6 100644
--- a/pyecsca/ec/mod/base.py
+++ b/pyecsca/ec/mod/base.py
@@ -3,7 +3,7 @@ import secrets
from functools import lru_cache, wraps
from public import public
-from typing import Tuple, Any, Dict, Type, Set, TypeVar
+from typing import Tuple, Any, Dict, Type, Set, TypeVar, Generic
from pyecsca.ec.context import ResultAction
from pyecsca.misc.cfg import getconfig
@@ -238,7 +238,7 @@ _mod_order = ["gmp", "flint", "python"]
@public
-class Mod:
+class Mod(Generic[M]):
"""
An element x of ℤₙ.
@@ -280,25 +280,25 @@ class Mod:
raise TypeError("Abstract")
@_check
- def __add__(self, other) -> "Mod":
+ def __add__(self: M, other) -> M:
return self.__class__((self.x + other.x) % self.n, self.n)
@_check
- def __radd__(self, other) -> "Mod":
+ def __radd__(self: M, other) -> M:
return self + other
@_check
- def __sub__(self, other) -> "Mod":
+ def __sub__(self: M, other) -> M:
return self.__class__((self.x - other.x) % self.n, self.n)
@_check
- def __rsub__(self, other) -> "Mod":
+ def __rsub__(self: M, other) -> M:
return -self + other
- def __neg__(self) -> "Mod":
+ def __neg__(self: M) -> M:
return self.__class__(self.n - self.x, self.n)
- def bit_length(self):
+ def bit_length(self: M) -> int:
"""
Compute the bit length of this element (in its positive integer representation).
@@ -306,7 +306,7 @@ class Mod:
"""
raise NotImplementedError
- def inverse(self) -> "Mod":
+ def inverse(self: M) -> M:
"""
Invert the element.
@@ -315,14 +315,14 @@ class Mod:
"""
raise NotImplementedError
- def __invert__(self) -> "Mod":
+ def __invert__(self: M) -> M:
return self.inverse()
- def is_residue(self) -> bool:
+ def is_residue(self: M) -> bool:
"""Whether this element is a quadratic residue (only implemented for prime modulus)."""
raise NotImplementedError
- def sqrt(self) -> "Mod":
+ def sqrt(self: M) -> M:
"""
Compute the modular square root of this element (only implemented for prime modulus).
@@ -330,13 +330,13 @@ class Mod:
"""
raise NotImplementedError
- def is_cubic_residue(self) -> bool:
+ def is_cubic_residue(self: M) -> bool:
"""
Whether this element is a cubic residue (only implemented for prime modulus).
"""
raise NotImplementedError
- def cube_root(self) -> "Mod":
+ def cube_root(self: M) -> M:
"""
Compute the cube root of this element (only implemented for prime modulus).
@@ -345,33 +345,33 @@ class Mod:
raise NotImplementedError
@_check
- def __mul__(self, other) -> "Mod":
+ def __mul__(self: M, other) -> M:
return self.__class__((self.x * other.x) % self.n, self.n)
@_check
- def __rmul__(self, other) -> "Mod":
+ def __rmul__(self: M, other) -> M:
return self * other
@_check
- def __truediv__(self, other) -> "Mod":
+ def __truediv__(self: M, other) -> M:
return self * ~other
@_check
- def __rtruediv__(self, other) -> "Mod":
+ def __rtruediv__(self: M, other) -> M:
return ~self * other
@_check
- def __floordiv__(self, other) -> "Mod":
+ def __floordiv__(self: M, other) -> M:
return self * ~other
@_check
- def __rfloordiv__(self, other) -> "Mod":
+ def __rfloordiv__(self: M, other) -> M:
return ~self * other
- def __bytes__(self) -> bytes:
+ def __bytes__(self: M) -> bytes:
raise NotImplementedError
- def __int__(self) -> int:
+ def __int__(self: M) -> int:
raise NotImplementedError
@classmethod
@@ -385,18 +385,18 @@ class Mod:
with RandomModAction(n) as action:
return action.exit(mod(secrets.randbelow(n), n))
- def __pow__(self, n, _=None) -> "Mod":
+ def __pow__(self: M, n, _=None) -> M:
return NotImplemented
- def __str__(self):
+ def __str__(self: M):
return str(self.x)
- def __format__(self, format_spec):
+ def __format__(self: M, format_spec):
return format(int(self), format_spec)
@public
-class Undefined(Mod):
+class Undefined(Mod["Undefined"]):
"""A special undefined element."""
__slots__ = ("x", "n")
@@ -432,6 +432,12 @@ class Undefined(Mod):
def is_residue(self):
raise NotImplementedError
+ def cube_root(self):
+ raise NotImplementedError
+
+ def is_cubic_residue(self):
+ raise NotImplementedError
+
def __invert__(self):
raise NotImplementedError