aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJ08nY2025-03-29 18:52:30 +0100
committerJ08nY2025-03-29 18:52:30 +0100
commit87ceaa26009266a429352aa11bac586cb4bf5132 (patch)
tree385d2a41ec977affc3be0e487ac34439366f8414
parent0ac91af59627a7ce2ad3eca8f26c263e9e275c88 (diff)
downloadpyecsca-87ceaa26009266a429352aa11bac586cb4bf5132.tar.gz
pyecsca-87ceaa26009266a429352aa11bac586cb4bf5132.tar.zst
pyecsca-87ceaa26009266a429352aa11bac586cb4bf5132.zip
-rw-r--r--pyecsca/ec/mod/base.py89
-rw-r--r--pyecsca/ec/mod/flint.py56
-rw-r--r--pyecsca/ec/mod/gmp.py52
-rw-r--r--pyecsca/ec/mod/raw.py70
4 files changed, 136 insertions, 131 deletions
diff --git a/pyecsca/ec/mod/base.py b/pyecsca/ec/mod/base.py
index 7a275c0..39243c0 100644
--- a/pyecsca/ec/mod/base.py
+++ b/pyecsca/ec/mod/base.py
@@ -3,12 +3,15 @@ import secrets
from functools import lru_cache, wraps
from public import public
-from typing import Tuple, Any, Dict, Type, Set
+from typing import Tuple, Any, Dict, Type, Set, TypeVar
from pyecsca.ec.context import ResultAction
from pyecsca.misc.cfg import getconfig
+M = TypeVar("M", bound="Mod")
+
+
@public
def gcd(a: int, b: int) -> int:
"""Euclid's greatest common denominator algorithm."""
@@ -65,8 +68,9 @@ def jacobi(x: int, n: int) -> int:
@public
-def square_roots(x: "Mod") -> Set["Mod"]:
+def square_roots(x: M) -> Set[M]:
"""
+ Compute all square roots of x.
:param x:
:return:
@@ -74,12 +78,13 @@ def square_roots(x: "Mod") -> Set["Mod"]:
if not x.is_residue():
return set()
sqrt = x.sqrt()
- return {sqrt, -sqrt}
+ return {sqrt, -sqrt} # type: ignore
@public
-def cube_roots(x: "Mod") -> Set["Mod"]:
+def cube_roots(x: M) -> Set[M]:
"""
+ Compute all cube roots of x.
:param x:
:return:
@@ -89,7 +94,7 @@ def cube_roots(x: "Mod") -> Set["Mod"]:
cube_root = x.cube_root()
if (x.n - 1) % 3 != 0:
# gcd(3, p - 1) = 1
- return {cube_root}
+ return {cube_root} # type: ignore
else:
# gcd(3, p - 1) = 3
m = (x.n - 1) // 3
@@ -99,7 +104,79 @@ def cube_roots(x: "Mod") -> Set["Mod"]:
r = z ** m
if r != 1:
break
- return {cube_root, cube_root * r, cube_root * (r ** 2)}
+ return {cube_root, cube_root * r, cube_root * (r ** 2)} # type: ignore
+
+
+def square_root_inner(x: M, intwrap, mod_class) -> M:
+ if x.n % 4 == 3:
+ return x ** int((x.n + 1) // 4) # type: ignore
+ q = x.n - 1
+ s = 0
+ while q % 2 == 0:
+ q //= 2
+ s += 1
+
+ z = intwrap(2)
+ while mod_class(z, x.n).is_residue():
+ z += 1
+
+ m = s
+ c = mod_class(z, x.n) ** q
+ t = x ** q
+ r_exp = (q + 1) // 2
+ r = x ** r_exp
+
+ while t != 1:
+ i = 1
+ while not (t ** (2 ** i)) == 1:
+ i += 1
+ two_exp = m - (i + 1)
+ b = c ** int(mod_class(2, x.n) ** two_exp)
+ m = int(mod_class(i, x.n))
+ c = b ** 2
+ t *= c
+ r *= b
+ return r
+
+
+def cube_root_inner(x: M, intwrap, mod_class) -> M:
+ if x.n % 3 == 2:
+ inv3 = mod_class(intwrap(3), x.n - 1).inverse()
+ return x ** int(inv3) # type: ignore
+ q = x.n - 1
+ s = 0
+ while q % 3 == 0:
+ q //= 3
+ s += 1
+ t = q
+ if t % 3 == 1:
+ k = (t - 1) // 3
+ else:
+ k = (t + 1) // 3
+
+ b = intwrap(2)
+ while mod_class(b, x.n).is_cubic_residue():
+ b += 1
+
+ c = mod_class(b, x.n) ** t
+ r = x ** t
+ h = mod_class(intwrap(1), x.n)
+ cp = c ** (3 ** (s - 1))
+ c = c.inverse()
+ for i in range(1, s):
+ d = r ** (3 ** (s - i - 1))
+ if d == cp:
+ h *= c
+ r *= c ** 3
+ elif d != 1:
+ h *= c ** 2
+ r *= c ** 6
+ c = c ** 3
+ x = (x ** k) * h
+ if t % 3 == 1:
+ return x.inverse() # type: ignore
+ else:
+ return x
@public
diff --git a/pyecsca/ec/mod/flint.py b/pyecsca/ec/mod/flint.py
index 6a64eeb..3436953 100644
--- a/pyecsca/ec/mod/flint.py
+++ b/pyecsca/ec/mod/flint.py
@@ -1,5 +1,5 @@
import warnings
-from functools import lru_cache, wraps
+from functools import lru_cache, wraps, partial
from typing import Union
from public import public
@@ -10,7 +10,7 @@ from pyecsca.ec.error import (
NonResidueError,
NonResidueWarning,
)
-from pyecsca.ec.mod.base import Mod
+from pyecsca.ec.mod.base import Mod, cube_root_inner, square_root_inner
has_flint = False
try:
@@ -109,35 +109,33 @@ if has_flint:
except (ValueError, DomainError):
raise_non_residue()
- if mod % 4 == 3:
- return self ** int((mod + 1) // 4)
- q = mod - 1
- s = 0
- while q % 2 == 0:
- q //= 2
- s += 1
-
- z = self._ctx(2)
- while FlintMod(z, self._ctx, ensure=False).is_residue():
- z += 1
+ if self.x == 0:
+ return FlintMod(self._ctx(0), self._ctx, ensure=False)
+ if not self.is_residue():
+ raise_non_residue()
+ return square_root_inner(self, self._ctx, lambda x: FlintMod(x, self._ctx, ensure=False))
- m = s
- c = FlintMod(z, self._ctx, ensure=False) ** int(q)
- t = self ** int(q)
- r_exp = (q + 1) // 2
- r = self ** int(r_exp)
+ def is_cubic_residue(self) -> bool:
+ if not _fmpz_is_prime(self.n):
+ raise NotImplementedError
+ if self.x in (0, 1):
+ return True
+ if self.n % 3 == 2:
+ return True
+ pm1 = self.n - 1
+ r = self ** (pm1 // 3)
+ return r == 1
- while t != 1:
- i = 1
- while not (t ** (2**i)) == 1:
- i += 1
- two_exp = m - (i + 1)
- b = c ** int(FlintMod(self._ctx(2), self._ctx, ensure=False) ** two_exp)
- m = int(FlintMod(self._ctx(i), self._ctx, ensure=False))
- c = b**2
- t *= c
- r *= b
- return r
+ def cube_root(self) -> "FlintMod":
+ if not _fmpz_is_prime(self.n):
+ raise NotImplementedError
+ if self.x == 0:
+ return FlintMod(self._ctx(0), self._ctx, ensure=False)
+ if self.x == 1:
+ return FlintMod(self._ctx(1), self._ctx, ensure=False)
+ if not self.is_cubic_residue():
+ raise_non_residue()
+ return cube_root_inner(self, self._ctx, lambda x: FlintMod(x, self._ctx, ensure=False))
@_flint_check
def __add__(self, other) -> "FlintMod":
diff --git a/pyecsca/ec/mod/gmp.py b/pyecsca/ec/mod/gmp.py
index 87301c0..a8b589e 100644
--- a/pyecsca/ec/mod/gmp.py
+++ b/pyecsca/ec/mod/gmp.py
@@ -1,9 +1,9 @@
-from functools import lru_cache, wraps
+from functools import lru_cache, wraps, partial
from typing import Union
from public import public
-from pyecsca.ec.mod.base import Mod
+from pyecsca.ec.mod.base import Mod, cube_root_inner, square_root_inner
from pyecsca.ec.error import (
raise_non_invertible,
raise_non_residue,
@@ -88,35 +88,29 @@ if has_gmp:
return GMPMod(gmpy2.mpz(0), self.n, ensure=False)
if not self.is_residue():
raise_non_residue()
- if self.n % 4 == 3:
- return self ** int((self.n + 1) // 4)
- q = self.n - 1
- s = 0
- while q % 2 == 0:
- q //= 2
- s += 1
+ return square_root_inner(self, gmpy2.mpz, partial(GMPMod, ensure=False))
- z = gmpy2.mpz(2)
- while GMPMod(z, self.n, ensure=False).is_residue():
- z += 1
-
- m = s
- c = GMPMod(z, self.n, ensure=False) ** int(q)
- t = self ** int(q)
- r_exp = (q + 1) // 2
- r = self ** int(r_exp)
+ def is_cubic_residue(self) -> bool:
+ if not _gmpy_is_prime(self.n):
+ raise NotImplementedError
+ if self.x in (0, 1):
+ return True
+ if self.n % 3 == 2:
+ return True
+ pm1 = self.n - 1
+ r = self ** (pm1 // 3)
+ return r == 1
- while t != 1:
- i = 1
- while not (t ** (2**i)) == 1:
- i += 1
- two_exp = m - (i + 1)
- b = c ** int(GMPMod(gmpy2.mpz(2), self.n, ensure=False) ** two_exp)
- m = int(GMPMod(gmpy2.mpz(i), self.n, ensure=False))
- c = b**2
- t *= c
- r *= b
- return r
+ def cube_root(self) -> "GMPMod":
+ if not _gmpy_is_prime(self.n):
+ raise NotImplementedError
+ if self.x == 0:
+ return GMPMod(gmpy2.mpz(0), self.n, ensure=False)
+ if self.x == 1:
+ return GMPMod(gmpy2.mpz(1), self.n, ensure=False)
+ if not self.is_cubic_residue():
+ raise_non_residue()
+ return cube_root_inner(self, gmpy2.mpz, partial(GMPMod, ensure=False))
@_check
def __add__(self, other) -> "GMPMod":
diff --git a/pyecsca/ec/mod/raw.py b/pyecsca/ec/mod/raw.py
index 0d76594..1e1aa45 100644
--- a/pyecsca/ec/mod/raw.py
+++ b/pyecsca/ec/mod/raw.py
@@ -4,7 +4,7 @@ from pyecsca.ec.error import (
raise_non_residue,
)
-from pyecsca.ec.mod.base import Mod, extgcd, miller_rabin, jacobi
+from pyecsca.ec.mod.base import Mod, extgcd, miller_rabin, jacobi, cube_root_inner, square_root_inner
@public
@@ -47,35 +47,7 @@ class RawMod(Mod):
return RawMod(0, self.n)
if not self.is_residue():
raise_non_residue()
- if self.n % 4 == 3:
- return self ** int((self.n + 1) // 4)
- q = self.n - 1
- s = 0
- while q % 2 == 0:
- q //= 2
- s += 1
-
- z = 2
- while RawMod(z, self.n).is_residue():
- z += 1
-
- m = s
- c = RawMod(z, self.n) ** q
- t = self**q
- r_exp = (q + 1) // 2
- r = self**r_exp
-
- while t != 1:
- i = 1
- while not (t ** (2**i)) == 1:
- i += 1
- two_exp = m - (i + 1)
- b = c ** int(RawMod(2, self.n) ** two_exp)
- m = int(RawMod(i, self.n))
- c = b**2
- t *= c
- r *= b
- return r
+ return square_root_inner(self, int, RawMod)
def is_cubic_residue(self):
if not miller_rabin(self.n):
@@ -97,43 +69,7 @@ class RawMod(Mod):
return RawMod(1, self.n)
if not self.is_cubic_residue():
raise_non_residue()
- if self.n % 3 == 2:
- inv3 = RawMod(3, self.n - 1).inverse()
- return self ** int(inv3)
- q = self.n - 1
- s = 0
- while q % 3 == 0:
- q //= 3
- s += 1
- t = q
- if t % 3 == 1:
- k = (t - 1) // 3
- else:
- k = (t + 1) // 3
-
- b = 2
- while RawMod(b, self.n).is_cubic_residue():
- b += 1
-
- c = RawMod(b, self.n) ** t
- r = self ** t
- h = RawMod(1, self.n)
- cp = c ** (3 ** (s - 1))
- c = c.inverse()
- for i in range(1, s):
- d = r ** (3 ** (s - i - 1))
- if d == cp:
- h *= c
- r *= c ** 3
- elif d != 1:
- h *= c ** 2
- r *= c ** 6
- c = c ** 3
- x: RawMod = (self ** k) * h
- if t % 3 == 1:
- return x.inverse()
- else:
- return x
+ return cube_root_inner(self, int, RawMod)
def __bytes__(self):
return self.x.to_bytes((self.n.bit_length() + 7) // 8, byteorder="big")