aboutsummaryrefslogtreecommitdiffhomepage
path: root/pyecsca/ec
diff options
context:
space:
mode:
authorJ08nY2023-08-01 19:01:56 +0200
committerJ08nY2023-08-01 19:01:56 +0200
commit6dcf7835175a9c3c575ce484740d340a7f6f8f68 (patch)
tree8d32e8499a8a904e58d64aae242b97dfac3a1c84 /pyecsca/ec
parent60e14db3e932c7ba6c640e1059cf6f06215de3a1 (diff)
downloadpyecsca-6dcf7835175a9c3c575ce484740d340a7f6f8f68.tar.gz
pyecsca-6dcf7835175a9c3c575ce484740d340a7f6f8f68.tar.zst
pyecsca-6dcf7835175a9c3c575ce484740d340a7f6f8f68.zip
Add multiplication-by-n polynomial computation to divpoly.
Diffstat (limited to 'pyecsca/ec')
-rw-r--r--pyecsca/ec/divpoly.py263
1 files changed, 263 insertions, 0 deletions
diff --git a/pyecsca/ec/divpoly.py b/pyecsca/ec/divpoly.py
new file mode 100644
index 0000000..2f2b635
--- /dev/null
+++ b/pyecsca/ec/divpoly.py
@@ -0,0 +1,263 @@
+from typing import Tuple, Dict, Union, Set, Mapping
+
+from sympy import symbols, FF, Poly
+import networkx as nx
+
+from .curve import EllipticCurve
+from .mod import Mod
+from .model import ShortWeierstrassModel
+
+
+def values(*ns: int) -> Mapping[int, Tuple[int, ...]]:
+ done: Set[int] = set()
+ vals = {}
+ todo: Set[int] = set()
+ todo.update(ns)
+ while todo:
+ val = todo.pop()
+ if val in done:
+ continue
+ new: Tuple[int, ...] = ()
+ if val == -2:
+ new = (-1,)
+ elif val == -1:
+ pass
+ elif val < 0:
+ raise ValueError(f"bad {val}")
+ elif val in (0, 1, 2, 3):
+ pass
+ elif val == 4:
+ new = (-2, 3)
+ elif val % 2 == 0:
+ m = (val - 2) // 2
+ new = (m + 1, m + 3, m, m - 1, m + 2)
+ else:
+ m = (val - 1) // 2
+ if m % 2 == 0:
+ new = (-2, m + 2, m, m - 1, m + 1)
+ else:
+ new = (m + 2, m, -2, m - 1, m + 1)
+ if new:
+ todo.update(new)
+ vals[val] = new
+ done.add(val)
+ return vals
+
+
+def dep_graph(*ns: int):
+ g = nx.DiGraph()
+ vals = values(*ns)
+ for k, v in vals.items():
+ if v:
+ for e in v:
+ g.add_edge(k, e)
+ else:
+ g.add_node(k)
+ return g, vals
+
+
+def dep_map(*ns: int):
+ g, vals = dep_graph(*ns)
+ current: Set[int] = set()
+ ls = []
+ for vert in nx.lexicographical_topological_sort(g, key=lambda v: -sum(g[v].keys())):
+ if vert in current:
+ current.remove(vert)
+ ls.append((vert, set(current)))
+ current.update(vals[vert])
+ ls.reverse()
+ return ls, vals
+
+
+def a_invariants(curve: EllipticCurve) -> Tuple[Mod, ...]:
+ """
+ Compute the a-invariants of the curve.
+
+ :param curve: The elliptic curve (only ShortWeierstrass model).
+ :return: A tuple of 5 a-invariants (a1, a2, a3, a4, a6).
+ """
+ if isinstance(curve.model, ShortWeierstrassModel):
+ a1 = Mod(0, curve.prime)
+ a2 = Mod(0, curve.prime)
+ a3 = Mod(0, curve.prime)
+ a4 = curve.parameters["a"]
+ a6 = curve.parameters["b"]
+ return a1, a2, a3, a4, a6
+ else:
+ raise NotImplementedError
+
+
+def b_invariants(curve: EllipticCurve) -> Tuple[Mod, ...]:
+ """
+ Compute the b-invariants of the curve.
+
+ :param curve: The elliptic curve (only ShortWeierstrass model).
+ :return: A tuple of 4 b-invariants (b2, b4, b6, b8).
+ """
+ if isinstance(curve.model, ShortWeierstrassModel):
+ a1, a2, a3, a4, a6 = a_invariants(curve)
+ return (a1 * a1 + 4 * a2,
+ a1 * a3 + 2 * a4,
+ a3 ** 2 + 4 * a6,
+ a1 ** 2 * a6 + 4 * a2 * a6 - a1 * a3 * a4 + a2 * a3 ** 2 - a4 ** 2)
+ else:
+ raise NotImplementedError
+
+
+def divpoly0(curve: EllipticCurve, *ns: int) -> Mapping[int, Poly]:
+ """
+ Basically sagemath's division_polynomial_0 but more clever memory management.
+
+ As sagemath says:
+
+ Return the `n^{th}` torsion (division) polynomial, without
+ the 2-torsion factor if `n` is even, as a polynomial in `x`.
+
+ These are the polynomials `g_n` defined in [MT1991]_, but with
+ the sign flipped for even `n`, so that the leading coefficient is
+ always positive.
+
+ :param curve: The elliptic curve.
+ :param ns: The values to compute the polynomial for.
+ :return:
+ """
+ xs = symbols("x")
+
+ K = FF(curve.prime)
+ Kx = lambda r: Poly(r, xs, domain=K) # noqa
+
+ x = Kx(xs)
+
+ b2, b4, b6, b8 = map(lambda b: Kx(int(b)), b_invariants(curve))
+ ls, vals = dep_map(*ns)
+
+ mem: Dict[int, Poly] = {}
+ for i, keep in ls:
+ if i == -2:
+ val = mem[-1] ** 2
+ elif i == -1:
+ val = Kx(4) * x ** 3 + b2 * x ** 2 + Kx(2) * b4 * x + b6
+ elif i == 0:
+ val = Kx(0)
+ elif i < 0:
+ raise ValueError("n must be a positive integer (or -1 or -2)")
+ elif i == 1 or i == 2:
+ val = Kx(1)
+ elif i == 3:
+ val = Kx(3) * x ** 4 + b2 * x ** 3 + Kx(3) * b4 * x ** 2 + Kx(3) * b6 * x + b8
+ elif i == 4:
+ val = -mem[-2] + (Kx(6) * x ** 2 + b2 * x + b4) * mem[3]
+ elif i % 2 == 0:
+ m = (i - 2) // 2
+ val = mem[m + 1] * (mem[m + 3] * mem[m] ** 2 - mem[m - 1] * mem[m + 2] ** 2)
+ else:
+ m = (i - 1) // 2
+ if m % 2 == 0:
+ val = mem[-2] * mem[m + 2] * mem[m] ** 3 - mem[m - 1] * mem[m + 1] ** 3
+ else:
+ val = mem[m + 2] * mem[m] ** 3 - mem[-2] * mem[m - 1] * mem[m + 1] ** 3
+ for dl in set(mem.keys()).difference(keep).difference(ns):
+ del mem[dl]
+ mem[i] = val
+
+ return mem
+
+
+def divpoly(curve: EllipticCurve, n: int, two_torsion_multiplicity: int = 2) -> Poly:
+ """
+ Compute the n-th division polynomial.
+
+ :param curve:
+ :param n:
+ :param two_torsion_multiplicity:
+ :return:
+ """
+ f: Poly = divpoly0(curve, n)[n]
+ a1, a2, a3, a4, a6 = a_invariants(curve)
+ xs, ys = symbols("x y")
+ x = Poly(xs, xs, domain=f.domain)
+ y = Poly(ys, ys, domain=f.domain)
+
+ if two_torsion_multiplicity == 0:
+ return f
+ elif two_torsion_multiplicity == 1:
+ if n % 2 == 0:
+ Kxy = lambda r: Poly(r, xs, ys, domain=f.domain) # noqa
+ return Kxy(f) * (Kxy(2) * y + Kxy(a1) * x + Kxy(a3))
+ else:
+ return f
+ elif two_torsion_multiplicity == 2:
+ if n % 2 == 0:
+ return f * divpoly0(curve, -1)[-1]
+ else:
+ return f
+
+
+def mult_by_n(curve: EllipticCurve, n: int) -> Tuple[Tuple[Poly, Poly], Tuple[Poly, Poly]]:
+ """
+ Compute the multiplication-by-n map on an elliptic curve.
+
+ :param curve: Curve to compute on.
+ :param n: Scalar.
+ :return:
+ """
+ xs, ys = symbols("x y")
+ K = FF(curve.prime)
+ x = Poly(xs, xs, domain=K)
+ y = Poly(ys, ys, domain=K)
+ Kxy = lambda r: Poly(r, xs, ys, domain=K) # noqa
+
+ if n == 1:
+ return x
+
+ a1, a2, a3, a4, a6 = a_invariants(curve)
+
+ polys = divpoly0(curve, -2, -1, n - 1, n, n + 1, n + 2)
+ mx_denom = polys[n] ** 2
+ if n % 2 == 0:
+ mx_num = x * polys[-1] * polys[n] ** 2 - polys[n - 1] * polys[n + 1]
+ mx_denom *= polys[-1]
+ else:
+ mx_num = x * polys[n] ** 2 - polys[-1] * polys[n - 1] * polys[n + 1]
+
+ # Alternative that makes the denominator monic by dividing the
+ # numerator by the leading coefficient. Sage does this
+ # simplification when asking for multiplication_by_m with the
+ # x-only=True, as then the poly is an univariate object.
+ # lc = K(mx_denom.LC())
+ # mx = (mx_num.quo(lc), mx_denom.monic())
+ mx = (mx_num, mx_denom)
+
+ # The following lines compute
+ # my = ((2*y+a1*x+a3)*mx.derivative(x)/m - a1*mx-a3)/2
+ # just as sage does, but using sympy and step-by-step
+ # tracking the numerator and denominator of the fraction.
+
+ # mx.derivative()
+ mxd_num = mx[1] * mx[0].diff() - mx[0] * mx[1].diff()
+ mxd_denom = mx[1] ** 2
+
+ # mx.derivative()/m
+ mxd_dn_num = mxd_num
+ mxd_dn_denom = mxd_denom * Kxy(n)
+
+ # (2*y+a1*x+a3)*mx.derivative(x)/m
+ mxd_full_num = mxd_dn_num * (Kxy(2) * y + Kxy(a1) * x + Kxy(a3))
+ mxd_full_denom = mxd_dn_denom
+
+ # a1*mx
+ a1mx_num = (Kxy(a1) * mx[0]).quo(Kxy(2))
+ a1mx_denom = mx[1] # noqa
+
+ # a3
+ a3_num = Kxy(a3) * mx[1]
+ a3_denom = mx[1] # noqa
+
+ # The mx.derivative part has a different denominator, basically mx[1]^2 * m
+ # so the rest needs to be multiplied by this factor when subtracitng.
+ mxd_fact = mx[1] * n
+
+ my_num = (mxd_full_num - a1mx_num * mxd_fact - a3_num * mxd_fact)
+ my_denom = mxd_full_denom * Kxy(2)
+ my = (my_num, my_denom)
+ return mx, my