aboutsummaryrefslogtreecommitdiffhomepage
path: root/pyecsca/ec
diff options
context:
space:
mode:
authorJ08nY2020-12-10 00:02:15 +0100
committerJ08nY2020-12-10 00:02:15 +0100
commit0bbc82710badf00431d160cb1785f90c2d2aa99d (patch)
treee62aa6186b2858a51c21da9555215e8bebe73497 /pyecsca/ec
parentf6fb6e452d39fb87b1b690460fb9011566119f69 (diff)
downloadpyecsca-0bbc82710badf00431d160cb1785f90c2d2aa99d.tar.gz
pyecsca-0bbc82710badf00431d160cb1785f90c2d2aa99d.tar.zst
pyecsca-0bbc82710badf00431d160cb1785f90c2d2aa99d.zip
Add support for GMP modular arithmetic.
Diffstat (limited to 'pyecsca/ec')
-rw-r--r--pyecsca/ec/mod.py304
1 files changed, 226 insertions, 78 deletions
diff --git a/pyecsca/ec/mod.py b/pyecsca/ec/mod.py
index fa5f866..4e69790 100644
--- a/pyecsca/ec/mod.py
+++ b/pyecsca/ec/mod.py
@@ -1,6 +1,15 @@
import random
import secrets
from functools import wraps, lru_cache
+from abc import ABC, abstractmethod
+
+has_gmp = False
+try:
+ import gmpy2
+
+ has_gmp = True
+except ImportError:
+ pass
from public import public
@@ -82,6 +91,16 @@ def check(func):
@public
+class NonInvertibleError(ArithmeticError):
+ pass
+
+
+@public
+class NonResidueError(ArithmeticError):
+ pass
+
+
+@public
class RandomModAction(ResultAction):
"""A random sampling from Z_n."""
order: int
@@ -94,17 +113,14 @@ class RandomModAction(ResultAction):
return f"{self.__class__.__name__}({self.order:x})"
-@public
-class Mod(object):
- """An element x of ℤₙ."""
-
- def __init__(self, x: int, n: int):
- self.x: int = x % n
- self.n: int = n
+class BaseMod(ABC):
+ def __init__(self, x, n):
+ self.x = x
+ self.n = n
@check
def __add__(self, other):
- return Mod((self.x + other.x) % self.n, self.n)
+ return self.__class__((self.x + other.x) % self.n, self.n)
@check
def __radd__(self, other):
@@ -112,22 +128,81 @@ class Mod(object):
@check
def __sub__(self, other):
- return Mod((self.x - other.x) % self.n, self.n)
+ return self.__class__((self.x - other.x) % self.n, self.n)
@check
def __rsub__(self, other):
return -self + other
def __neg__(self):
- return Mod(self.n - self.x, self.n)
+ return self.__class__(self.n - self.x, self.n)
+ @abstractmethod
def inverse(self):
- x, y, d = extgcd(self.x, self.n)
- return Mod(x, self.n)
+ ...
def __invert__(self):
return self.inverse()
+ @check
+ def __mul__(self, other):
+ return self.__class__((self.x * other.x) % self.n, self.n)
+
+ @check
+ def __rmul__(self, other):
+ return self * other
+
+ @check
+ def __truediv__(self, other):
+ return self * ~other
+
+ @check
+ def __rtruediv__(self, other):
+ return ~self * other
+
+ @check
+ def __floordiv__(self, other):
+ return self * ~other
+
+ @check
+ def __rfloordiv__(self, other):
+ return ~self * other
+
+ @check
+ def __div__(self, other):
+ return self.__floordiv__(other)
+
+ @check
+ def __rdiv__(self, other):
+ return self.__rfloordiv__(other)
+
+ @check
+ def __divmod__(self, divisor):
+ q, r = divmod(self.x, divisor.x)
+ return self.__class__(q, self.n), self.__class__(r, self.n)
+
+ @classmethod
+ def random(cls, n: int):
+ with RandomModAction(n) as action:
+ return action.exit(cls(secrets.randbelow(n), n))
+
+
+class RawMod(BaseMod):
+ """An element x of ℤₙ."""
+ x: int
+ n: int
+
+ def __init__(self, x: int, n: int):
+ super().__init__(x % n, n)
+
+ def inverse(self):
+ if self.x == 0:
+ raise NonInvertibleError("Inverting zero.")
+ x, y, d = extgcd(self.x, self.n)
+ if d != 1:
+ raise NonInvertibleError("Element not invertible.")
+ return RawMod(x, self.n)
+
def is_residue(self):
"""Whether this element is a quadratic residue (only implemented for prime modulus)."""
if not miller_rabin(self.n):
@@ -136,8 +211,8 @@ class Mod(object):
return True
if self.n == 2:
return self.x in (0, 1)
- legendre = self ** ((self.n - 1) // 2)
- return legendre == 1
+ legendre_symbol = self ** ((self.n - 1) // 2)
+ return legendre_symbol == 1
def sqrt(self):
"""
@@ -147,6 +222,13 @@ class Mod(object):
"""
if not miller_rabin(self.n):
raise NotImplementedError
+ if not self.is_residue():
+ if self.x == 0:
+ return RawMod(0, self.n)
+ else:
+ raise NonResidueError("No square root exists.")
+ if self.n % 4 == 3:
+ return self ** int((self.n + 1) // 4)
q = self.n - 1
s = 0
while q % 2 == 0:
@@ -154,79 +236,37 @@ class Mod(object):
s += 1
z = 2
- while Mod(z, self.n).is_residue():
+ while RawMod(z, self.n).is_residue():
z += 1
m = s
- c = Mod(z, self.n) ** q
+ c = RawMod(z, self.n) ** q
t = self ** q
r_exp = (q + 1) // 2
r = self ** r_exp
while t != 1:
i = 1
- while not (t ** (2**i)) == 1:
+ while not (t ** (2 ** i)) == 1:
i += 1
two_exp = m - (i + 1)
- b = c ** int(Mod(2, self.n)**two_exp)
- m = int(Mod(i, self.n))
+ b = c ** int(RawMod(2, self.n) ** two_exp)
+ m = int(RawMod(i, self.n))
c = b ** 2
t *= c
r *= b
return r
- @check
- def __mul__(self, other):
- return Mod((self.x * other.x) % self.n, self.n)
-
- @check
- def __rmul__(self, other):
- return self * other
-
- @check
- def __truediv__(self, other):
- return self * ~other
-
- @check
- def __rtruediv__(self, other):
- return ~self * other
-
- @check
- def __floordiv__(self, other):
- return self * ~other
-
- @check
- def __rfloordiv__(self, other):
- return ~self * other
-
- @check
- def __div__(self, other):
- return self.__floordiv__(other)
-
- @check
- def __rdiv__(self, other):
- return self.__rfloordiv__(other)
-
- @check
- def __divmod__(self, divisor):
- q, r = divmod(self.x, divisor.x)
- return Mod(q, self.n), Mod(r, self.n)
-
def __bytes__(self):
return self.x.to_bytes((self.n.bit_length() + 7) // 8, byteorder="big")
- @staticmethod
- def random(n: int):
- with RandomModAction(n) as action:
- return action.exit(Mod(secrets.randbelow(n), n))
-
def __int__(self):
return self.x
def __eq__(self, other):
if type(other) is int:
return self.x == (other % self.n)
- if type(other) is not Mod:
+ if type(other) is not RawMod:
return False
return self.x == other.x and self.n == other.n
@@ -240,29 +280,19 @@ class Mod(object):
if type(n) is not int:
raise TypeError
if n == 0:
- return Mod(1, self.n)
+ return RawMod(1, self.n)
if n < 0:
- return self.inverse()**(-n)
+ return self.inverse() ** (-n)
if n == 1:
- return Mod(self.x, self.n)
-
- q = self
- r = self if n & 1 else Mod(1, self.n)
+ return RawMod(self.x, self.n)
- i = 2
- while i <= n:
- q = (q * q)
- if n & i == i:
- r = (q * r)
- i = i << 1
- return r
+ return RawMod(pow(self.x, n, self.n), self.n)
@public
-class Undefined(Mod):
-
+class Undefined(BaseMod):
def __init__(self):
- object.__init__(self)
+ super().__init__(None, None)
def __add__(self, other):
raise NotImplementedError
@@ -279,6 +309,9 @@ class Undefined(Mod):
def __neg__(self):
raise NotImplementedError
+ def inverse(self):
+ raise NotImplementedError
+
def __invert__(self):
raise NotImplementedError
@@ -326,3 +359,118 @@ class Undefined(Mod):
def __pow__(self, n):
raise NotImplementedError
+
+if has_gmp:
+
+ class GMPMod(BaseMod):
+ """An element x of ℤₙ. Implemented by GMP."""
+ x: gmpy2.mpz
+ n: gmpy2.mpz
+
+ def __init__(self, x: int, n: int):
+ super().__init__(gmpy2.mpz(x % n), gmpy2.mpz(n))
+
+ def inverse(self):
+ if self.x == 0:
+ raise NonInvertibleError("Inverting zero!")
+ if self.x == 1:
+ return GMPMod(1, self.n)
+ try:
+ res = gmpy2.invert(self.x, self.n)
+ except ZeroDivisionError:
+ raise NonInvertibleError("Element not invertible.")
+ return GMPMod(res, self.n)
+
+ def is_residue(self):
+ """Whether this element is a quadratic residue (only implemented for prime modulus)."""
+ if not gmpy2.is_prime(self.n):
+ raise NotImplementedError
+ if self.x == 0:
+ return True
+ if self.n == 2:
+ return self.x in (0, 1)
+ return gmpy2.legendre(self.x, self.n) == 1
+
+ def sqrt(self):
+ """
+ The modular square root of this element (only implemented for prime modulus).
+
+ Uses the `Tonelli-Shanks <https://en.wikipedia.org/wiki/Tonelli–Shanks_algorithm>`_ algorithm.
+ """
+ if not gmpy2.is_prime(self.n):
+ raise NotImplementedError
+ if not self.is_residue():
+ if self.x == 0:
+ return GMPMod(0, self.n)
+ else:
+ raise NonResidueError("No square root exists.")
+ if self.n % 4 == 3:
+ return self ** int((self.n + 1) // 4)
+ q = self.n - 1
+ s = 0
+ while q % 2 == 0:
+ q //= 2
+ s += 1
+
+ z = 2
+ while GMPMod(z, self.n).is_residue():
+ z += 1
+
+ m = s
+ c = GMPMod(z, self.n) ** int(q)
+ t = self ** int(q)
+ r_exp = (q + 1) // 2
+ r = self ** int(r_exp)
+
+ while t != 1:
+ i = 1
+ while not (t ** (2 ** i)) == 1:
+ i += 1
+ two_exp = m - (i + 1)
+ b = c ** int(GMPMod(2, self.n) ** two_exp)
+ m = int(GMPMod(i, self.n))
+ c = b ** 2
+ t *= c
+ r *= b
+ return r
+
+ @check
+ def __divmod__(self, divisor):
+ q, r = gmpy2.f_divmod(self.x, divisor.x)
+ return GMPMod(q, self.n), GMPMod(r, self.n)
+
+ def __bytes__(self):
+ return int(self.x).to_bytes((self.n.bit_length() + 7) // 8, byteorder="big")
+
+ def __int__(self):
+ return int(self.x)
+
+ def __eq__(self, other):
+ if type(other) is int:
+ return self.x == (gmpy2.mpz(other) % self.n)
+ if type(other) is not GMPMod:
+ return False
+ return self.x == other.x and self.n == other.n
+
+ def __ne__(self, other):
+ return not self == other
+
+ def __repr__(self):
+ return str(int(self.x))
+
+ def __pow__(self, n):
+ if type(n) not in (int, gmpy2.mpz):
+ raise TypeError
+ if n == 0:
+ return GMPMod(1, self.n)
+ if n < 0:
+ return self.inverse() ** (-n)
+ if n == 1:
+ return GMPMod(self.x, self.n)
+ return GMPMod(gmpy2.powmod(self.x, gmpy2.mpz(n), self.n), self.n)
+
+ Mod = GMPMod
+else:
+ Mod = RawMod
+
+public(Mod=Mod)