aboutsummaryrefslogtreecommitdiff
path: root/util/ec.py
diff options
context:
space:
mode:
Diffstat (limited to 'util/ec.py')
-rw-r--r--util/ec.py1010
1 files changed, 1010 insertions, 0 deletions
diff --git a/util/ec.py b/util/ec.py
new file mode 100644
index 0000000..4e3244a
--- /dev/null
+++ b/util/ec.py
@@ -0,0 +1,1010 @@
+"""
+This module is a collection of modules glued together, to provide basic
+elliptic curve arithmetic for curves over prime and binary fields. It consists of
+ - tinyec: https://github.com/alexmgr/tinyec (GPL v3 licensed)
+ - pyfinite: https://github.com/emin63/pyfinite (MIT licensed)
+ - modular square root from https://eli.thegreenplace.net/2009/03/07/computing-modular-square-roots-in-python
+ - and some of my own code: https://github.com/J08nY
+"""
+
+import abc
+import random
+from functools import reduce, wraps
+from os import path
+
+
+def legendre_symbol(a, p):
+ """ Compute the Legendre symbol a|p using
+ Euler's criterion. p is a prime, a is
+ relatively prime to p (if p divides
+ a, then a|p = 0)
+
+ Returns 1 if a has a square root modulo
+ p, -1 otherwise.
+ """
+ ls = pow(a, (p - 1) // 2, p)
+ return -1 if ls == p - 1 else ls
+
+
+def is_prime(n, trials=50):
+ """
+ Miller-Rabin primality test.
+ """
+ s = 0
+ d = n - 1
+ while d % 2 == 0:
+ d >>= 1
+ s += 1
+ assert (2 ** s * d == n - 1)
+
+ def trial_composite(a):
+ if pow(a, d, n) == 1:
+ return False
+ for i in range(s):
+ if pow(a, 2 ** i * d, n) == n - 1:
+ return False
+ return True
+
+ for i in range(trials): # number of trials
+ a = random.randrange(2, n)
+ if trial_composite(a):
+ return False
+ return True
+
+
+def gcd(a, b):
+ """Euclid's greatest common denominator algorithm."""
+ if abs(a) < abs(b):
+ return gcd(b, a)
+
+ while abs(b) > 0:
+ q, r = divmod(a, b)
+ a, b = b, r
+
+ return a
+
+
+def extgcd(a, b):
+ """Extended Euclid's greatest common denominator algorithm."""
+ if abs(b) > abs(a):
+ (x, y, d) = extgcd(b, a)
+ return y, x, d
+
+ if abs(b) == 0:
+ return 1, 0, a
+
+ x1, x2, y1, y2 = 0, 1, 1, 0
+ while abs(b) > 0:
+ q, r = divmod(a, b)
+ x = x2 - q * x1
+ y = y2 - q * y1
+ a, b, x2, x1, y2, y1 = b, r, x1, x, y1, y
+
+ return x2, y2, a
+
+
+def check(func):
+ @wraps(func)
+ def method(self, other):
+ if isinstance(other, int):
+ other = self.__class__(other, self.field)
+ if type(self) is type(other):
+ if self.field == other.field:
+ return func(self, other)
+ else:
+ raise ValueError
+ else:
+ raise TypeError
+
+ return method
+
+
+class Mod(object):
+ """An element x of ℤₙ."""
+
+ def __init__(self, x: int, n: int):
+ self.x = x % n
+ self.field = n
+
+ @check
+ def __add__(self, other):
+ return Mod((self.x + other.x) % self.field, self.field)
+
+ @check
+ def __radd__(self, other):
+ return self + other
+
+ @check
+ def __sub__(self, other):
+ return Mod((self.x - other.x) % self.field, self.field)
+
+ @check
+ def __rsub__(self, other):
+ return -self + other
+
+ def __neg__(self):
+ return Mod(self.field - self.x, self.field)
+
+ def inverse(self):
+ x, y, d = extgcd(self.x, self.field)
+ return Mod(x, self.field)
+
+ def __invert__(self):
+ return self.inverse()
+
+ @check
+ def __mul__(self, other):
+ return Mod((self.x * other.x) % self.field, self.field)
+
+ @check
+ def __rmul__(self, other):
+ return self * other
+
+ @check
+ def __truediv__(self, other):
+ return self * ~other
+
+ @check
+ def __rtruediv__(self, other):
+ return ~self * other
+
+ @check
+ def __floordiv__(self, other):
+ return self * ~other
+
+ @check
+ def __rfloordiv__(self, other):
+ return ~self * other
+
+ @check
+ def __div__(self, other):
+ return self.__floordiv__(other)
+
+ @check
+ def __rdiv__(self, other):
+ return self.__rfloordiv__(other)
+
+ @check
+ def __divmod__(self, divisor):
+ q, r = divmod(self.x, divisor.x)
+ return Mod(q, self.field), Mod(r, self.field)
+
+ def __int__(self):
+ return self.x
+
+ def __eq__(self, other):
+ if type(other) is not Mod:
+ return False
+ return self.x == other.x and self.field == other.field
+
+ def __ne__(self, other):
+ return not self == other
+
+ def __repr__(self):
+ return str(self.x)
+
+ def __pow__(self, n):
+ if not isinstance(n, int):
+ raise TypeError
+ if n == 0:
+ return Mod(1, self.field)
+ if n < 0:
+ return (~self) ** -n
+ if n == 1:
+ return self
+ if n == 2:
+ return self * self
+
+ q = self
+ r = self if n & 1 else Mod(1, self.field)
+
+ i = 2
+ while i <= n:
+ q = (q * q)
+ if n & i == i:
+ r = (q * r)
+ i = i << 1
+ return r
+
+ def sqrt(self):
+ if not is_prime(self.field):
+ raise NotImplementedError
+ # Simple cases
+ if legendre_symbol(self.x, self.field) != 1 or self.x == 0 or self.field == 2:
+ raise ValueError("Not a quadratic residue.")
+ if self.field % 4 == 3:
+ return self ** ((self.field + 1) // 4)
+
+ a = self.x
+ p = self.field
+ s = p - 1
+ e = 0
+ while s % 2 == 0:
+ s /= 2
+ e += 1
+
+ n = 2
+ while legendre_symbol(n, p) != -1:
+ n += 1
+
+ x = pow(a, (s + 1) / 2, p)
+ b = pow(a, s, p)
+ g = pow(n, s, p)
+ r = e
+
+ while True:
+ t = b
+ m = 0
+ for m in range(r):
+ if t == 1:
+ break
+ t = pow(t, 2, p)
+
+ if m == 0:
+ return Mod(x, p)
+
+ gs = pow(g, 2 ** (r - m - 1), p)
+ g = (gs * gs) % p
+ x = (x * gs) % p
+ b = (b * g) % p
+ r = m
+
+
+class FField(object):
+ """
+ The FField class implements a binary field.
+ """
+
+ def __init__(self, n, gen):
+ """
+ This method constructs the field GF(2^n). It takes two
+ required arguments, n and gen,
+ representing the coefficients of the generator polynomial
+ (of degree n) to use.
+ Note that you can look at the generator for the field object
+ F by looking at F.generator.
+ """
+
+ self.n = n
+ if len(gen) != n + 1:
+ full_gen = [0] * (n + 1)
+ for i in gen:
+ full_gen[i] = 1
+ gen = full_gen[::-1]
+ self.generator = self.to_element(gen)
+ self.unity = 1
+
+ def add(self, x, y):
+ """
+ Adds two field elements and returns the result.
+ """
+
+ return x ^ y
+
+ def subtract(self, x, y):
+ """
+ Subtracts the second argument from the first and returns
+ the result. In fields of characteristic two this is the same
+ as the Add method.
+ """
+ return x ^ y
+
+ def multiply(self, f, v):
+ """
+ Multiplies two field elements (modulo the generator
+ self.generator) and returns the result.
+
+ See MultiplyWithoutReducing if you don't want multiplication
+ modulo self.generator.
+ """
+ m = self.multiply_no_reduce(f, v)
+ return self.full_division(m, self.generator, self.find_degree(m), self.n)[1]
+
+ def inverse(self, f):
+ """
+ Computes the multiplicative inverse of its argument and
+ returns the result.
+ """
+ return self.ext_gcd(self.unity, f, self.generator, self.find_degree(f), self.n)[1]
+
+ def divide(self, f, v):
+ """
+ Divide(f,v) returns f * v^-1.
+ """
+ return self.multiply(f, self.inverse(v))
+
+ def exponentiate(self, f, n):
+ """
+ Exponentiate(f, n) returns f^n.
+ """
+ if not isinstance(n, int):
+ raise TypeError
+ if n == 0:
+ return self.unity
+ if n < 0:
+ f = self.inverse(f)
+ n = -n
+ if n == 1:
+ return f
+ if n == 2:
+ return self.multiply(f, f)
+
+ q = f
+ r = f if n & 1 else self.unity
+
+ i = 2
+ while i <= n:
+ q = self.multiply(q, q)
+ if n & i == i:
+ r = self.multiply(q, r)
+ i = i << 1
+ return r
+
+ def sqrt(self, f):
+ return self.exponentiate(f, (2 ** self.n) - 1)
+
+ def trace(self, f):
+ t = f
+ for _ in range(1, self.n):
+ t = self.add(self.multiply(t, t), f)
+ return t
+
+ def half_trace(self, f):
+ if self.n % 2 != 1:
+ raise ValueError
+ h = f
+ for _ in range(1, (self.n - 1) // 2):
+ h = self.multiply(h, h)
+ h = self.add(self.multiply(h, h), f)
+ return h
+
+ def find_degree(self, v):
+ """
+ Find the degree of the polynomial representing the input field
+ element v. This takes O(degree(v)) operations.
+
+ A faster version requiring only O(log(degree(v)))
+ could be written using binary search...
+ """
+ if v:
+ return v.bit_length() - 1
+ else:
+ return 0
+
+ def multiply_no_reduce(self, f, v):
+ """
+ Multiplies two field elements and does not take the result
+ modulo self.generator. You probably should not use this
+ unless you know what you are doing; look at Multiply instead.
+ """
+
+ result = 0
+ mask = self.unity
+ for i in range(self.n + 1):
+ if mask & v:
+ result = result ^ f
+ f = f << 1
+ mask = mask << 1
+ return result
+
+ def ext_gcd(self, d, a, b, a_degree, b_degree):
+ """
+ Takes arguments (d,a,b,aDegree,bDegree) where d = gcd(a,b)
+ and returns the result of the extended Euclid algorithm
+ on (d,a,b).
+ """
+ if b == 0:
+ return a, self.unity, 0
+ else:
+ (floorADivB, aModB) = self.full_division(a, b, a_degree, b_degree)
+ (d, x, y) = self.ext_gcd(d, b, aModB, b_degree, self.find_degree(aModB))
+ return d, y, self.subtract(x, self.multiply(floorADivB, y))
+
+ def full_division(self, f, v, f_degree, v_degree):
+ """
+ Takes four arguments, f, v, fDegree, and vDegree where
+ fDegree and vDegree are the degrees of the field elements
+ f and v represented as a polynomials.
+ This method returns the field elements a and b such that
+
+ f(x) = a(x) * v(x) + b(x).
+
+ That is, a is the divisor and b is the remainder, or in
+ other words a is like floor(f/v) and b is like f modulo v.
+ """
+
+ result = 0
+ mask = self.unity << f_degree
+ for i in range(f_degree, v_degree - 1, -1):
+ if mask & f:
+ result = result ^ (self.unity << (i - v_degree))
+ f = self.subtract(f, v << (i - v_degree))
+ mask = mask >> self.unity
+ return result, f
+
+ def coefficients(self, f):
+ """
+ Show coefficients of input field element represented as a
+ polynomial in decreasing order.
+ """
+
+ result = []
+ for i in range(self.n, -1, -1):
+ if (self.unity << i) & f:
+ result.append(1)
+ else:
+ result.append(0)
+
+ return result
+
+ def polynomial(self, f):
+ """
+ Show input field element represented as a polynomial.
+ """
+
+ f_degree = self.find_degree(f)
+ result = ''
+
+ if f == 0:
+ return '0'
+
+ for i in range(f_degree, 0, -1):
+ if (1 << i) & f:
+ result = result + (' x^' + repr(i))
+ if 1 & f:
+ result = result + ' ' + repr(1)
+ return result.strip().replace(' ', ' + ')
+
+ def to_element(self, l):
+ """
+ This method takes as input a binary list (e.g. [1, 0, 1, 1])
+ and converts it to a decimal representation of a field element.
+ For example, [1, 0, 1, 1] is mapped to 8 | 2 | 1 = 11.
+
+ Note if the input list is of degree >= to the degree of the
+ generator for the field, then you will have to call take the
+ result modulo the generator to get a proper element in the
+ field.
+ """
+
+ temp = map(lambda a, b: a << b, l, range(len(l) - 1, -1, -1))
+ return reduce(lambda a, b: a | b, temp)
+
+ def __str__(self):
+ return "F_(2^{}): {}".format(self.n, self.polynomial(self.generator))
+
+ def __repr__(self):
+ return str(self)
+
+
+class FElement(object):
+ """
+ This class provides field elements which overload the
+ +,-,*,%,//,/ operators to be the appropriate field operation.
+ Note that before creating FElement objects you must first
+ create an FField object.
+ """
+
+ def __init__(self, f, field):
+ """
+ The constructor takes two arguments, field, and e where
+ field is an FField object and e is an integer representing
+ an element in FField.
+
+ The result is a new FElement instance.
+ """
+ self.f = f
+ self.field = field
+
+ @check
+ def __add__(self, other):
+ return FElement(self.field.add(self.f, other.f), self.field)
+
+ @check
+ def __sub__(self, other):
+ return FElement(self.field.add(self.f, other.f), self.field)
+
+ def __neg__(self):
+ return self
+
+ @check
+ def __mul__(self, other):
+ return FElement(self.field.multiply(self.f, other.f), self.field)
+
+ @check
+ def __floordiv__(self, o):
+ return FElement(self.field.full_division(self.f, o.f,
+ self.field.find_degree(self.f),
+ self.field.find_degree(o.f))[0], self.field)
+
+ @check
+ def __truediv__(self, other):
+ return FElement(self.field.divide(self.f, other.f), self.field)
+
+ def __div__(self, *args, **kwargs):
+ return self.__truediv__(*args, **kwargs)
+
+ @check
+ def __divmod__(self, other):
+ d, m = self.field.full_division(self.f, other.f,
+ self.field.find_degree(self.f),
+ self.field.find_degree(other.f))
+ return FElement(d, self.field), FElement(m, self.field)
+
+ def inverse(self):
+ return FElement(self.field.inverse(self.f), self.field)
+
+ def __invert__(self):
+ return self.inverse()
+
+ def sqrt(self):
+ return FElement(self.field.sqrt(self.f), self.field)
+
+ def trace(self):
+ return FElement(self.field.trace(self.f), self.field)
+
+ def half_trace(self):
+ return FElement(self.field.half_trace(self.f), self.field)
+
+ def __pow__(self, power, modulo=None):
+ return FElement(self.field.exponentiate(self.f, power), self.field)
+
+ def __str__(self):
+ return str(int(self))
+
+ def __repr__(self):
+ return str(self)
+
+ def __int__(self):
+ return self.f
+
+ def __eq__(self, other):
+ if not isinstance(other, FElement):
+ return False
+ if self.field != other.field:
+ return False
+ return self.f == other.f
+
+
+class Curve(object):
+ __metaclass__ = abc.ABCMeta
+
+ def __init__(self, field, a, b, group, name=None):
+ self.field = field
+ if name is None:
+ name = "undefined"
+ self.name = name
+ self.a = a
+ self.b = b
+ self.group = group
+ self.g = Point(self, self.group.g[0], self.group.g[1])
+
+ @abc.abstractmethod
+ def is_singular(self):
+ ...
+
+ @abc.abstractmethod
+ def on_curve(self, x, y):
+ ...
+
+ @abc.abstractmethod
+ def add(self, x1, y1, x2, y2):
+ ...
+
+ @abc.abstractmethod
+ def dbl(self, x, y):
+ ...
+
+ @abc.abstractmethod
+ def neg(self, x, y):
+ ...
+
+ @abc.abstractmethod
+ def encode_point(self, point, compressed=False):
+ ...
+
+ @abc.abstractmethod
+ def decode_point(self, byte_data):
+ ...
+
+ def bit_size(self):
+ return self.group.n.bit_length()
+
+ def byte_size(self):
+ return (self.bit_size() + 7) // 8
+
+ @abc.abstractmethod
+ def field_size(self):
+ ...
+
+ def __eq__(self, other):
+ if not isinstance(other, Curve):
+ return False
+ return self.field == other.field and self.a == other.a and self.b == other.b and self.group == other.group
+
+ def __repr__(self):
+ return str(self)
+
+
+class CurveFp(Curve):
+ def is_singular(self):
+ return (4 * self.a ** 3 + 27 * self.b ** 2) == 0
+
+ def on_curve(self, x, y):
+ return (y ** 2 - x ** 3 - self.a * x - self.b) == 0
+
+ def add(self, x1, y1, x2, y2):
+ lm = (y2 - y1) / (x2 - x1)
+ x3 = lm ** 2 - x1 - x2
+ y3 = lm * (x1 - x3) - y1
+ return x3, y3
+
+ def dbl(self, x, y):
+ lm = (3 * x ** 2 + self.a) / (2 * y)
+ x3 = lm ** 2 - (2 * x)
+ y3 = lm * (x - x3) - y
+ return x3, y3
+
+ def mul(self, k, x, y, z=1):
+ def _add(x1, y1, z1, x2, y2, z2):
+ yz = y1 * z2
+ xz = x1 * z2
+ zz = z1 * z2
+ u = y2 * z1 - yz
+ uu = u ** 2
+ v = x2 * z1 - xz
+ vv = v ** 2
+ vvv = v * vv
+ r = vv * xz
+ a = uu * zz - vvv - 2 * r
+ x3 = v * a
+ y3 = u * (r - a) - vvv * yz
+ z3 = vvv * zz
+ return x3, y3, z3
+
+ def _dbl(x1, y1, z1):
+ xx = x1 ** 2
+ zz = z1 ** 2
+ w = self.a * zz + 3 * xx
+ s = 2 * y1 * z1
+ ss = s ** 2
+ sss = s * ss
+ r = y1 * s
+ rr = r ** 2
+ b = (x1 + r) ** 2 - xx - rr
+ h = w ** 2 - 2 * b
+ x3 = h * s
+ y3 = w * (b - h) - 2 * rr
+ z3 = sss
+ return x3, y3, z3
+ r0 = (x, y, z)
+ r1 = _dbl(x, y, z)
+ for i in range(k.bit_length() - 2, -1, -1):
+ if k & (1 << i):
+ r0 = _add(*r0, *r1)
+ r1 = _dbl(*r1)
+ else:
+ r1 = _add(*r0, *r1)
+ r0 = _dbl(*r0)
+ rx, ry, rz = r0
+ rzi = ~rz
+ return rx * rzi, ry * rzi
+
+ def neg(self, x, y):
+ return x, -y
+
+ def field_size(self):
+ return self.field.bit_length()
+
+ def encode_point(self, point, compressed=False):
+ byte_size = (self.field_size() + 7) // 8
+ if not compressed:
+ return bytes((0x04,)) + int(point.x).to_bytes(byte_size, byteorder="big") + int(
+ point.y).to_bytes(byte_size, byteorder="big")
+ else:
+ yp = int(point.y) & 1
+ pc = bytes((0x02 | yp,))
+ return pc + int(point.x).to_bytes(byte_size, byteorder="big")
+
+ def decode_point(self, byte_data):
+ if byte_data[0] == 0 and len(byte_data) == 1:
+ return Inf(self)
+ byte_size = (self.field_size() + 7) // 8
+ if byte_data[0] in (0x04, 0x06):
+ if len(byte_data) != 1 + byte_size * 2:
+ raise ValueError
+ x = Mod(int.from_bytes(byte_data[1:byte_size + 1], byteorder="big"), self.field)
+ y = Mod(int.from_bytes(byte_data[byte_size + 1:], byteorder="big"), self.field)
+ return Point(self, x, y)
+ elif byte_data[0] in (0x02, 0x03):
+ if len(byte_data) != 1 + byte_size:
+ raise ValueError
+ x = Mod(int.from_bytes(byte_data[1:byte_size + 1], byteorder="big"), self.field)
+ rhs = x ** 3 + self.a * x + self.b
+ sqrt = rhs.sqrt()
+ yp = byte_data[0] & 1
+ if int(sqrt) & 1 == yp:
+ return Point(self, x, sqrt)
+ else:
+ return Point(self, x, self.field - sqrt)
+ raise ValueError
+
+ def __str__(self):
+ return "\"{}\": y^2 = x^3 + {}x + {} over {}".format(self.name, self.a, self.b, self.field)
+
+
+class CurveF2m(Curve):
+ def is_singular(self):
+ return self.b == 0
+
+ def on_curve(self, x, y):
+ return (y ** 2 + x * y - x ** 3 - self.a * x ^ 2 - self.b) == 0
+
+ def add(self, x1, y1, x2, y2):
+ lm = (y1 + y2) / (x1 + x2)
+ x3 = lm ** 2 + lm + x1 + x2 + self.a
+ y3 = lm * (x1 + x3) + x3 + y1
+ return x3, y3
+
+ def dbl(self, x, y):
+ lm = x + y / x
+ x3 = lm ** 2 + lm + self.a
+ y3 = x ** 2 + lm * x3 + x3
+ return x3, y3
+
+ def mul(self, k, x, y, z=1):
+ def _add(x1, y1, z1, x2, y2, z2):
+ a = x1 * z2
+ b = x2 * z1
+ c = a ** 2
+ d = b ** 2
+ e = a + b
+ f = c + d
+ g = y1 * (z2 ** 2)
+ h = y2 * (z1 ** 2)
+ i = g + h
+ j = i * e
+ z3 = f * z1 * z2
+ x3 = a * (h + d) + b * (c + g)
+ y3 = (a * j + f * g) * f + (j + z3) * x3
+ return x3, y3, z3
+
+ def _dbl(x1, y1, z1):
+ a = x1 * z1
+ b = x1 * x1
+ c = b + y1
+ d = a * c
+ z3 = a * a
+ x3 = c ** 2 + d + self.a * z3
+ y3 = (z3 + d) * x3 + b ** 2 * z3
+ return x3, y3, z3
+ r0 = (x, y, z)
+ r1 = _dbl(x, y, z)
+ for i in range(k.bit_length() - 2, -1, -1):
+ if k & (1 << i):
+ r0 = _add(*r0, *r1)
+ r1 = _dbl(*r1)
+ else:
+ r1 = _add(*r0, *r1)
+ r0 = _dbl(*r0)
+ rx, ry, rz = r0
+ rzi = ~rz
+ return rx * rzi, ry * (rzi ** 2)
+
+ def neg(self, x, y):
+ return x, x + y
+
+ def field_size(self):
+ return self.field.n
+
+ def encode_point(self, point, compressed=False):
+ byte_size = (self.field_size() + 7) // 8
+ if not compressed:
+ return bytes((0x04,)) + int(point.x).to_bytes(byte_size, byteorder="big") + int(
+ point.y).to_bytes(byte_size, byteorder="big")
+ else:
+ if int(point.x) == 0:
+ yp = 0
+ else:
+ yp = int(point.y * point.x.inverse())
+ pc = bytes((0x02 | yp,))
+ return pc + int(point.x).to_bytes(byte_size, byteorder="big")
+
+ def decode_point(self, byte_data):
+ if byte_data[0] == 0 and len(byte_data) == 1:
+ return Inf(self)
+ byte_size = (self.field_size() + 7) // 8
+ if byte_data[0] in (0x04, 0x06):
+ if len(byte_data) != 1 + byte_size * 2:
+ raise ValueError
+ x = FElement(int.from_bytes(byte_data[1:byte_size + 1], byteorder="big"), self.field)
+ y = FElement(int.from_bytes(byte_data[byte_size + 1:], byteorder="big"), self.field)
+ return Point(self, x, y)
+ elif byte_data[0] in (0x02, 0x03):
+ if self.field.n % 2 != 1:
+ raise NotImplementedError
+ x = FElement(int.from_bytes(byte_data[1:byte_size + 1], byteorder="big"), self.field)
+ yp = byte_data[0] & 1
+ if int(x) == 0:
+ y = self.b ** (2 ** (self.field.n - 1))
+ else:
+ rhs = x + self.a + self.b * x ** (-2)
+ z = rhs.half_trace()
+ if z ** 2 + z != rhs:
+ raise ValueError
+ if int(z) & 1 != yp:
+ z += 1
+ y = x * z
+ return Point(self, x, y)
+ raise ValueError
+
+ def __str__(self):
+ return "\"{}\" => y^2 + xy = x^3 + {}x^2 + {} over {}".format(self.name, self.a, self.b,
+ self.field)
+
+
+class SubGroup(object):
+ def __init__(self, g, n, h):
+ self.g = g
+ self.n = n
+ self.h = h
+
+ def __eq__(self, other):
+ if not isinstance(other, SubGroup):
+ return False
+ return self.g == other.g and self.n == other.n and self.h == other.h
+
+ def __str__(self):
+ return "Subgroup => generator {}, order: {}, cofactor: {}".format(self.g, self.n, self.h)
+
+ def __repr__(self):
+ return str(self)
+
+
+class Inf(object):
+ def __init__(self, curve, x=None, y=None):
+ self.x = x
+ self.y = y
+ self.curve = curve
+
+ def __eq__(self, other):
+ if not isinstance(other, Inf):
+ return False
+ return self.curve == other.curve
+
+ def __ne__(self, other):
+ return not self.__eq__(other)
+
+ def __neg__(self):
+ return self
+
+ def __add__(self, other):
+ if isinstance(other, Inf):
+ return Inf(self.curve)
+ if isinstance(other, Point):
+ return other
+ raise TypeError(
+ "Unsupported operand type(s) for +: '%s' and '%s'" % (self.__class__.__name__,
+ other.__class__.__name__))
+
+ def __radd__(self, other):
+ return self + other
+
+ def __sub__(self, other):
+ if isinstance(other, Inf):
+ return Inf(self.curve)
+ if isinstance(other, Point):
+ return other
+ raise TypeError(
+ "Unsupported operand type(s) for +: '%s' and '%s'" % (self.__class__.__name__,
+ other.__class__.__name__))
+
+ def __str__(self):
+ return "{} on {}".format(self.__class__.__name__, self.curve)
+
+ def __repr__(self):
+ return str(self)
+
+
+class Point(object):
+ def __init__(self, curve, x, y):
+ self.curve = curve
+ self.x = x
+ self.y = y
+
+ def __eq__(self, other):
+ if not isinstance(other, Point):
+ return False
+ return self.x == other.x and self.y == other.y and self.curve == other.curve
+
+ def __ne__(self, other):
+ return not self.__eq__(other)
+
+ def __neg__(self):
+ return Point(self.curve, *self.curve.neg(self.x, self.y))
+
+ def __add__(self, other):
+ if isinstance(other, Inf):
+ return self
+ if isinstance(other, Point):
+ if self.curve != other.curve:
+ raise ValueError("Cannot add points belonging to different curves")
+ if self == -other:
+ return Inf(self.curve)
+ elif self == other:
+ return Point(self.curve, *self.curve.dbl(self.x, self.y))
+ else:
+ return Point(self.curve, *self.curve.add(self.x, self.y, other.x, other.y))
+ else:
+ raise TypeError(
+ "Unsupported operand type(s) for +: '{}' and '{}'".format(
+ self.__class__.__name__,
+ other.__class__.__name__))
+
+ def __radd__(self, other):
+ return self + other
+
+ def __sub__(self, other):
+ return self + (-other)
+
+ def __rsub__(self, other):
+ return self - other
+
+ def __mul__(self, other):
+ if isinstance(other, int):
+ if other % self.curve.group.n == 0:
+ return Inf(self.curve)
+ if other < 0:
+ other = -other
+ addend = -self
+ else:
+ addend = self
+ if hasattr(self.curve, "mul") and callable(getattr(self.curve, "mul")):
+ return Point(self.curve, *self.curve.mul(other, addend.x, addend.y))
+ else:
+ result = Inf(self.curve)
+ # Iterate over all bits starting by the LSB
+ for bit in reversed([int(i) for i in bin(abs(other))[2:]]):
+ if bit == 1:
+ result += addend
+ addend += addend
+ return result
+ else:
+ raise TypeError(
+ "Unsupported operand type(s) for *: '%s' and '%s'" % (other.__class__.__name__,
+ self.__class__.__name__))
+
+ def __rmul__(self, other):
+ return self * other
+
+ def __str__(self):
+ return "({}, {}) on {}".format(self.x, self.y, self.curve)
+
+ def __repr__(self):
+ return str(self)
+
+
+def load_curve(file, name=None):
+ data = file.read()
+ parts = list(map(lambda x: int(x, 16), data.split(",")))
+ if len(parts) == 7:
+ p, a, b, gx, gy, n, h = parts
+ g = (Mod(gx, p), Mod(gy, p))
+ group = SubGroup(g, n, h)
+ return CurveFp(p, Mod(a, p), Mod(b, p), group, name)
+ elif len(parts) == 10:
+ m, e1, e2, e3, a, b, gx, gy, n, h = parts
+ poly = [m, e1, e2, e3, 0]
+ field = FField(m, poly)
+ g = (FElement(gx, field), FElement(gy, field))
+ group = SubGroup(g, n, h)
+ return CurveF2m(field, FElement(a, field), FElement(b, field), group, name)
+ else:
+ raise ValueError("Invalid curve data")
+
+
+def get_curve(idd):
+ cat, i = idd.split("/")
+ with open(path.join("..", "src", "cz", "crcs", "ectester", "data", cat, i + ".csv"), "r") as f:
+ return load_curve(f, i)
+