aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorJ08nY2023-02-12 22:45:16 +0100
committerJ08nY2023-02-12 22:45:16 +0100
commita08a052db35e9b940b33b57750c5addf0f66facd (patch)
treedf67e3e3e85fa11db8a18e7ab56c31b521b84a7b
parentabd075a326ced7648e997d9ac7343b054f67962f (diff)
downloadpyecsca-a08a052db35e9b940b33b57750c5addf0f66facd.tar.gz
pyecsca-a08a052db35e9b940b33b57750c5addf0f66facd.tar.zst
pyecsca-a08a052db35e9b940b33b57750c5addf0f66facd.zip
Skip unnecessary modular reductions and casts in GMPMod.
-rw-r--r--pyecsca/ec/mod.py53
1 files changed, 35 insertions, 18 deletions
diff --git a/pyecsca/ec/mod.py b/pyecsca/ec/mod.py
index ebeea57..405ab19 100644
--- a/pyecsca/ec/mod.py
+++ b/pyecsca/ec/mod.py
@@ -584,23 +584,25 @@ if has_gmp:
def __new__(cls, *args, **kwargs):
return object.__new__(cls)
- def __init__(self, x: Union[int, gmpy2.mpz], n: Union[int, gmpy2.mpz]):
- self.n = gmpy2.mpz(n) if not type(n) is gmpy2.mpz else n
- self.x = gmpy2.mpz(x % self.n) if not type(x) is gmpy2.mpz else x % self.n
- # self.x = gmpy2.mpz(x % n)
- # self.n = gmpy2.mpz(n)
+ def __init__(self, x: Union[int, gmpy2.mpz], n: Union[int, gmpy2.mpz], ensure: bool = True):
+ if ensure:
+ self.n = gmpy2.mpz(n)
+ self.x = gmpy2.mpz(x % self.n)
+ else:
+ self.n = n
+ self.x = x
def inverse(self) -> "GMPMod":
if self.x == 0:
raise_non_invertible()
if self.x == 1:
- return GMPMod(1, self.n)
+ return GMPMod(gmpy2.mpz(1), self.n, ensure=False)
try:
res = gmpy2.invert(self.x, self.n)
except ZeroDivisionError:
raise_non_invertible()
- res = 0
- return GMPMod(res, self.n)
+ res = gmpy2.mpz(0)
+ return GMPMod(res, self.n, ensure=False)
def is_residue(self) -> bool:
if not _is_prime(self.n):
@@ -615,7 +617,7 @@ if has_gmp:
if not _is_prime(self.n):
raise NotImplementedError
if self.x == 0:
- return GMPMod(0, self.n)
+ return GMPMod(gmpy2.mpz(0), self.n, ensure=False)
if not self.is_residue():
raise_non_residue()
if self.n % 4 == 3:
@@ -626,12 +628,12 @@ if has_gmp:
q //= 2
s += 1
- z = 2
- while GMPMod(z, self.n).is_residue():
+ z = gmpy2.mpz(2)
+ while GMPMod(z, self.n, ensure=False).is_residue():
z += 1
m = s
- c = GMPMod(z, self.n) ** int(q)
+ c = GMPMod(z, self.n, ensure=False) ** int(q)
t = self ** int(q)
r_exp = (q + 1) // 2
r = self ** int(r_exp)
@@ -641,17 +643,32 @@ if has_gmp:
while not (t ** (2 ** i)) == 1:
i += 1
two_exp = m - (i + 1)
- b = c ** int(GMPMod(2, self.n) ** two_exp)
- m = int(GMPMod(i, self.n))
+ b = c ** int(GMPMod(gmpy2.mpz(2), self.n, ensure=False) ** two_exp)
+ m = int(GMPMod(gmpy2.mpz(i), self.n, ensure=False))
c = b ** 2
t *= c
r *= b
return r
@_check
+ def __add__(self, other) -> "GMPMod":
+ return GMPMod((self.x + other.x) % self.n, self.n, ensure=False)
+
+ @_check
+ def __sub__(self, other) -> "GMPMod":
+ return GMPMod((self.x - other.x) % self.n, self.n, ensure=False)
+
+ def __neg__(self) -> "GMPMod":
+ return GMPMod(self.n - self.x, self.n, ensure=False)
+
+ @_check
+ def __mul__(self, other) -> "GMPMod":
+ return GMPMod((self.x * other.x) % self.n, self.n, ensure=False)
+
+ @_check
def __divmod__(self, divisor) -> Tuple["GMPMod", "GMPMod"]:
q, r = gmpy2.f_divmod(self.x, divisor.x)
- return GMPMod(q, self.n), GMPMod(r, self.n)
+ return GMPMod(q, self.n, ensure=False), GMPMod(r, self.n, ensure=False)
def __bytes__(self):
return int(self.x).to_bytes((self.n.bit_length() + 7) // 8, byteorder="big")
@@ -679,11 +696,11 @@ if has_gmp:
if type(n) not in (int, gmpy2.mpz):
raise TypeError
if n == 0:
- return GMPMod(1, self.n)
+ return GMPMod(gmpy2.mpz(1), self.n, ensure=False)
if n < 0:
return self.inverse() ** (-n)
if n == 1:
- return GMPMod(self.x, self.n)
- return GMPMod(gmpy2.powmod(self.x, gmpy2.mpz(n), self.n), self.n)
+ return GMPMod(self.x, self.n, ensure=False)
+ return GMPMod(gmpy2.powmod(self.x, gmpy2.mpz(n), self.n), self.n, ensure=False)
_mod_classes["gmp"] = GMPMod