aboutsummaryrefslogtreecommitdiff
path: root/pyecsca/ec/mod.py
diff options
context:
space:
mode:
Diffstat (limited to 'pyecsca/ec/mod.py')
-rw-r--r--pyecsca/ec/mod.py28
1 files changed, 17 insertions, 11 deletions
diff --git a/pyecsca/ec/mod.py b/pyecsca/ec/mod.py
index cfe0748..cb9ef7e 100644
--- a/pyecsca/ec/mod.py
+++ b/pyecsca/ec/mod.py
@@ -8,7 +8,7 @@ dispatches to the implementation chosen by the runtime configuration of the libr
import random
import secrets
from functools import wraps, lru_cache
-from typing import Type, Dict
+from typing import Type, Dict, Any
from public import public
from sympy import Expr, FF
@@ -119,6 +119,8 @@ _mod_classes: Dict[str, Type] = {}
@public
class Mod(object):
"""An element x of ℤₙ."""
+ x: Any
+ n: Any
def __new__(cls, *args, **kwargs):
if cls != Mod:
@@ -131,10 +133,6 @@ class Mod(object):
selected_class = next(iter(_mod_classes.keys()))
return _mod_classes[selected_class].__new__(_mod_classes[selected_class], *args, **kwargs)
- def __init__(self, x, n):
- self.x = x
- self.n = n
-
@_check
def __add__(self, other):
return self.__class__((self.x + other.x) % self.n, self.n)
@@ -241,7 +239,8 @@ class RawMod(Mod):
return object.__new__(cls)
def __init__(self, x: int, n: int):
- super().__init__(x % n, n)
+ self.x = x % n
+ self.n = n
def inverse(self):
if self.x == 0:
@@ -344,7 +343,8 @@ class Undefined(Mod):
return object.__new__(cls)
def __init__(self):
- super().__init__(None, None)
+ self.x = None
+ self.n = None
def __add__(self, other):
raise NotImplementedError
@@ -447,7 +447,8 @@ class SymbolicMod(Mod):
return object.__new__(cls)
def __init__(self, x: Expr, n: int):
- super().__init__(x, n)
+ self.x = x
+ self.n = n
@_symbolic_check
def __add__(self, other):
@@ -539,6 +540,10 @@ class SymbolicMod(Mod):
if has_gmp:
+ @lru_cache
+ def _is_prime(x) -> bool:
+ return gmpy2.is_prime(x)
+
@public
class GMPMod(Mod):
"""An element x of ℤₙ. Implemented by GMP."""
@@ -549,7 +554,8 @@ if has_gmp:
return object.__new__(cls)
def __init__(self, x: int, n: int):
- super().__init__(gmpy2.mpz(x % n), gmpy2.mpz(n))
+ self.x = gmpy2.mpz(x % n)
+ self.n = gmpy2.mpz(n)
def inverse(self):
if self.x == 0:
@@ -565,7 +571,7 @@ if has_gmp:
def is_residue(self):
"""Whether this element is a quadratic residue (only implemented for prime modulus)."""
- if not gmpy2.is_prime(self.n):
+ if not _is_prime(self.n):
raise NotImplementedError
if self.x == 0:
return True
@@ -579,7 +585,7 @@ if has_gmp:
Uses the `Tonelli-Shanks <https://en.wikipedia.org/wiki/Tonelli–Shanks_algorithm>`_ algorithm.
"""
- if not gmpy2.is_prime(self.n):
+ if not _is_prime(self.n):
raise NotImplementedError
if self.x == 0:
return GMPMod(0, self.n)