diff options
| author | J08nY | 2023-02-12 22:45:16 +0100 |
|---|---|---|
| committer | J08nY | 2023-02-12 22:45:16 +0100 |
| commit | a08a052db35e9b940b33b57750c5addf0f66facd (patch) | |
| tree | df67e3e3e85fa11db8a18e7ab56c31b521b84a7b | |
| parent | abd075a326ced7648e997d9ac7343b054f67962f (diff) | |
| download | pyecsca-a08a052db35e9b940b33b57750c5addf0f66facd.tar.gz pyecsca-a08a052db35e9b940b33b57750c5addf0f66facd.tar.zst pyecsca-a08a052db35e9b940b33b57750c5addf0f66facd.zip | |
Skip unnecessary modular reductions and casts in GMPMod.
| -rw-r--r-- | pyecsca/ec/mod.py | 53 |
1 files changed, 35 insertions, 18 deletions
diff --git a/pyecsca/ec/mod.py b/pyecsca/ec/mod.py index ebeea57..405ab19 100644 --- a/pyecsca/ec/mod.py +++ b/pyecsca/ec/mod.py @@ -584,23 +584,25 @@ if has_gmp: def __new__(cls, *args, **kwargs): return object.__new__(cls) - def __init__(self, x: Union[int, gmpy2.mpz], n: Union[int, gmpy2.mpz]): - self.n = gmpy2.mpz(n) if not type(n) is gmpy2.mpz else n - self.x = gmpy2.mpz(x % self.n) if not type(x) is gmpy2.mpz else x % self.n - # self.x = gmpy2.mpz(x % n) - # self.n = gmpy2.mpz(n) + def __init__(self, x: Union[int, gmpy2.mpz], n: Union[int, gmpy2.mpz], ensure: bool = True): + if ensure: + self.n = gmpy2.mpz(n) + self.x = gmpy2.mpz(x % self.n) + else: + self.n = n + self.x = x def inverse(self) -> "GMPMod": if self.x == 0: raise_non_invertible() if self.x == 1: - return GMPMod(1, self.n) + return GMPMod(gmpy2.mpz(1), self.n, ensure=False) try: res = gmpy2.invert(self.x, self.n) except ZeroDivisionError: raise_non_invertible() - res = 0 - return GMPMod(res, self.n) + res = gmpy2.mpz(0) + return GMPMod(res, self.n, ensure=False) def is_residue(self) -> bool: if not _is_prime(self.n): @@ -615,7 +617,7 @@ if has_gmp: if not _is_prime(self.n): raise NotImplementedError if self.x == 0: - return GMPMod(0, self.n) + return GMPMod(gmpy2.mpz(0), self.n, ensure=False) if not self.is_residue(): raise_non_residue() if self.n % 4 == 3: @@ -626,12 +628,12 @@ if has_gmp: q //= 2 s += 1 - z = 2 - while GMPMod(z, self.n).is_residue(): + z = gmpy2.mpz(2) + while GMPMod(z, self.n, ensure=False).is_residue(): z += 1 m = s - c = GMPMod(z, self.n) ** int(q) + c = GMPMod(z, self.n, ensure=False) ** int(q) t = self ** int(q) r_exp = (q + 1) // 2 r = self ** int(r_exp) @@ -641,17 +643,32 @@ if has_gmp: while not (t ** (2 ** i)) == 1: i += 1 two_exp = m - (i + 1) - b = c ** int(GMPMod(2, self.n) ** two_exp) - m = int(GMPMod(i, self.n)) + b = c ** int(GMPMod(gmpy2.mpz(2), self.n, ensure=False) ** two_exp) + m = int(GMPMod(gmpy2.mpz(i), self.n, ensure=False)) c = b ** 2 t *= c r *= b return r @_check + def __add__(self, other) -> "GMPMod": + return GMPMod((self.x + other.x) % self.n, self.n, ensure=False) + + @_check + def __sub__(self, other) -> "GMPMod": + return GMPMod((self.x - other.x) % self.n, self.n, ensure=False) + + def __neg__(self) -> "GMPMod": + return GMPMod(self.n - self.x, self.n, ensure=False) + + @_check + def __mul__(self, other) -> "GMPMod": + return GMPMod((self.x * other.x) % self.n, self.n, ensure=False) + + @_check def __divmod__(self, divisor) -> Tuple["GMPMod", "GMPMod"]: q, r = gmpy2.f_divmod(self.x, divisor.x) - return GMPMod(q, self.n), GMPMod(r, self.n) + return GMPMod(q, self.n, ensure=False), GMPMod(r, self.n, ensure=False) def __bytes__(self): return int(self.x).to_bytes((self.n.bit_length() + 7) // 8, byteorder="big") @@ -679,11 +696,11 @@ if has_gmp: if type(n) not in (int, gmpy2.mpz): raise TypeError if n == 0: - return GMPMod(1, self.n) + return GMPMod(gmpy2.mpz(1), self.n, ensure=False) if n < 0: return self.inverse() ** (-n) if n == 1: - return GMPMod(self.x, self.n) - return GMPMod(gmpy2.powmod(self.x, gmpy2.mpz(n), self.n), self.n) + return GMPMod(self.x, self.n, ensure=False) + return GMPMod(gmpy2.powmod(self.x, gmpy2.mpz(n), self.n), self.n, ensure=False) _mod_classes["gmp"] = GMPMod |
