aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--docs/references.rst1
-rw-r--r--pyecsca/ec/key_agreement.py150
-rw-r--r--pyecsca/ec/mult/ladder.py80
-rw-r--r--pyecsca/ec/params.py7
-rw-r--r--pyecsca/sca/target/leakage.py4
-rw-r--r--test/conftest.py7
-rw-r--r--test/ec/test_key_agreement.py147
-rw-r--r--test/ec/test_mult.py88
-rw-r--r--test/sca/test_tree.py1
9 files changed, 424 insertions, 61 deletions
diff --git a/docs/references.rst b/docs/references.rst
index 9533448..555431f 100644
--- a/docs/references.rst
+++ b/docs/references.rst
@@ -18,3 +18,4 @@
.. [SG14] Shay Gueron & Vlad Krasnov. Fast prime field elliptic-curve cryptography with 256-bit primes, https://link.springer.com/article/10.1007/s13389-014-0090-x
.. [B51] Andrew D. Booth. A signed binary multiplication technique.
.. [M61] O.L. Macsorley. High-speed arithmetic in binary computers.
+.. [RFC7748] A. Langley, M. Hamburg, S. Turner, https://datatracker.ietf.org/doc/html/rfc7748
diff --git a/pyecsca/ec/key_agreement.py b/pyecsca/ec/key_agreement.py
index fe07b1f..c15c0b6 100644
--- a/pyecsca/ec/key_agreement.py
+++ b/pyecsca/ec/key_agreement.py
@@ -1,13 +1,16 @@
-"""Provides an implementation of ECDH (Elliptic Curve Diffie-Hellman)."""
+"""Provides an implementation of ECDH (Elliptic Curve Diffie-Hellman) and XDH (X25519, X448)."""
+
import hashlib
+from abc import abstractmethod, ABC
from typing import Optional, Any
from public import public
from pyecsca.ec.context import ResultAction
from pyecsca.ec.mod import Mod
+from pyecsca.ec.model import MontgomeryModel
from pyecsca.ec.mult import ScalarMultiplier
-from pyecsca.ec.params import DomainParameters
+from pyecsca.ec.params import DomainParameters, get_params
from pyecsca.ec.point import Point
@@ -38,7 +41,126 @@ class ECDHAction(ResultAction):
@public
-class KeyAgreement:
+class XDHAction(ResultAction):
+ """XDH key exchange."""
+
+ params: DomainParameters
+ privkey: int
+ pubkey: Point
+
+ def __init__(self, params: DomainParameters, privkey: int, pubkey: Point):
+ super().__init__()
+ self.params = params
+ self.privkey = privkey
+ self.pubkey = pubkey
+
+ def __repr__(self):
+ return (
+ f"{self.__class__.__name__}({self.params}, {self.privkey}, {self.pubkey})"
+ )
+
+
+@public
+class KeyAgreement(ABC):
+ """An abstract EC-based key agreement."""
+
+ @abstractmethod
+ def perform_raw(self) -> Point:
+ """
+ Perform the scalar-multiplication of the key agreement.
+
+ :return: The shared point.
+ """
+ ...
+
+ @abstractmethod
+ def perform(self) -> bytes:
+ """
+ Perform the key agreement operation.
+
+ :return: The shared secret.
+ """
+ ...
+
+
+@public
+class XDH(KeyAgreement):
+ def __init__(
+ self,
+ mult: ScalarMultiplier,
+ params: DomainParameters,
+ pubkey: Point,
+ privkey: int,
+ bits: int,
+ bytes: int,
+ ):
+ if "scl" not in mult.formulas:
+ raise ValueError("ScalarMultiplier needs to have the scaling formula.")
+ if not isinstance(params.curve.model, MontgomeryModel):
+ raise ValueError("Invalid curve model.")
+ self.mult = mult
+ self.params = params
+ self.pubkey = pubkey
+ self.privkey = privkey
+ self.bits = bits
+ self.bytes = bytes
+ self.mult.init(self.params, self.pubkey)
+
+ def clamp(self, scalar: int) -> int:
+ return scalar
+
+ def perform_raw(self) -> Point:
+ clamped = self.clamp(self.privkey)
+ return self.mult.multiply(clamped)
+
+ def perform(self) -> bytes:
+ with XDHAction(self.params, self.privkey, self.pubkey) as action:
+ point = self.perform_raw()
+ return action.exit(int(point.X).to_bytes(self.bytes, "little"))
+
+
+@public
+class X25519(XDH):
+ """
+ X25519 (or Curve25519) from [RFC7748]_.
+
+ .. warning::
+ You need to clear the top bit of the point coordinate (pubkey) before converting to a point.
+ """
+
+ def __init__(self, mult: ScalarMultiplier, pubkey: Point, privkey: int):
+ curve25519 = get_params(
+ "other", "Curve25519", pubkey.coordinate_model.name, infty=False
+ )
+ super().__init__(mult, curve25519, pubkey, privkey, 255, 32)
+
+ def clamp(self, scalar: int) -> int:
+ scalar &= ~7
+ scalar &= ~(128 << 8 * 31)
+ scalar |= 64 << 8 * 31
+ return scalar
+
+
+@public
+class X448(XDH):
+ """
+ X448 (or Curve448) from [RFC7748]_.
+ """
+
+ def __init__(self, mult: ScalarMultiplier, pubkey: Point, privkey: int):
+ curve448 = get_params(
+ "other", "Curve448", pubkey.coordinate_model.name, infty=False
+ )
+ super().__init__(mult, curve448, pubkey, privkey, 448, 56)
+
+ def clamp(self, scalar: int) -> int:
+ scalar &= ~3
+ scalar |= 128 << 8 * 55
+ return scalar
+
+
+@public
+class ECDH(KeyAgreement):
"""EC based key agreement primitive (ECDH)."""
mult: ScalarMultiplier
@@ -63,20 +185,10 @@ class KeyAgreement:
self.mult.init(self.params, self.pubkey)
def perform_raw(self) -> Point:
- """
- Perform the scalar-multiplication of the key agreement.
-
- :return: The shared point.
- """
point = self.mult.multiply(int(self.privkey))
return point.to_affine()
def perform(self) -> bytes:
- """
- Perform the key agreement operation.
-
- :return: The shared secret.
- """
with ECDHAction(
self.params, self.hash_algo, self.privkey, self.pubkey
) as action:
@@ -91,7 +203,7 @@ class KeyAgreement:
@public
-class ECDH_NONE(KeyAgreement):
+class ECDH_NONE(ECDH):
"""Raw x-coordinate ECDH."""
def __init__(
@@ -105,7 +217,7 @@ class ECDH_NONE(KeyAgreement):
@public
-class ECDH_SHA1(KeyAgreement):
+class ECDH_SHA1(ECDH):
"""ECDH with SHA1 of x-coordinate."""
def __init__(
@@ -119,7 +231,7 @@ class ECDH_SHA1(KeyAgreement):
@public
-class ECDH_SHA224(KeyAgreement):
+class ECDH_SHA224(ECDH):
"""ECDH with SHA224 of x-coordinate."""
def __init__(
@@ -133,7 +245,7 @@ class ECDH_SHA224(KeyAgreement):
@public
-class ECDH_SHA256(KeyAgreement):
+class ECDH_SHA256(ECDH):
"""ECDH with SHA256 of x-coordinate."""
def __init__(
@@ -147,7 +259,7 @@ class ECDH_SHA256(KeyAgreement):
@public
-class ECDH_SHA384(KeyAgreement):
+class ECDH_SHA384(ECDH):
"""ECDH with SHA384 of x-coordinate."""
def __init__(
@@ -161,7 +273,7 @@ class ECDH_SHA384(KeyAgreement):
@public
-class ECDH_SHA512(KeyAgreement):
+class ECDH_SHA512(ECDH):
"""ECDH with SHA512 of x-coordinate."""
def __init__(
diff --git a/pyecsca/ec/mult/ladder.py b/pyecsca/ec/mult/ladder.py
index 3e9e426..f2939af 100644
--- a/pyecsca/ec/mult/ladder.py
+++ b/pyecsca/ec/mult/ladder.py
@@ -72,7 +72,7 @@ class LadderMultiplier(ScalarMultiplier):
if self.complete:
p0 = copy(self._params.curve.neutral)
p1 = self._point
- top = self._params.order.bit_length() - 1
+ top = self._params.full_order.bit_length() - 1
else:
p0 = copy(q)
p1 = self._dbl(q)
@@ -88,6 +88,82 @@ class LadderMultiplier(ScalarMultiplier):
@public
+class SwapLadderMultiplier(ScalarMultiplier):
+ """
+ Montgomery ladder multiplier, using a three input, two output ladder formula.
+
+ Optionally takes a doubling formula, and if `complete` is false, it requires one.
+
+ :param short_circuit: Whether the use of formulas will be guarded by short-circuit on inputs
+ of the point at infinity.
+ :param complete: Whether it starts processing at full order-bit-length.
+ """
+
+ requires = {LadderFormula}
+ optionals = {DoublingFormula, ScalingFormula}
+ complete: bool
+ """Whether it starts processing at full order-bit-length."""
+
+ def __init__(
+ self,
+ ladd: LadderFormula,
+ dbl: Optional[DoublingFormula] = None,
+ scl: Optional[ScalingFormula] = None,
+ complete: bool = True,
+ short_circuit: bool = True,
+ ):
+ super().__init__(short_circuit=short_circuit, ladd=ladd, dbl=dbl, scl=scl)
+ self.complete = complete
+ if dbl is None:
+ if not complete:
+ raise ValueError("When complete is not set SwapLadderMultiplier requires a doubling formula.")
+ if short_circuit: # complete = True
+ raise ValueError("When short_circuit is set SwapLadderMultiplier requires a doubling formula.")
+
+ def __hash__(self):
+ return hash((LadderMultiplier, super().__hash__(), self.complete))
+
+ def __eq__(self, other):
+ if not isinstance(other, LadderMultiplier):
+ return False
+ return (
+ self.formulas == other.formulas
+ and self.short_circuit == other.short_circuit
+ and self.complete == other.complete
+ )
+
+ def __repr__(self):
+ return f"{self.__class__.__name__}({', '.join(map(str, self.formulas.values()))}, short_circuit={self.short_circuit}, complete={self.complete})"
+
+ def multiply(self, scalar: int) -> Point:
+ if not self._initialized:
+ raise ValueError("ScalarMultiplier not initialized.")
+ with ScalarMultiplicationAction(self._point, scalar) as action:
+ if scalar == 0:
+ return action.exit(copy(self._params.curve.neutral))
+ q = self._point
+ if self.complete:
+ p0 = copy(self._params.curve.neutral)
+ p1 = self._point
+ top = self._params.full_order.bit_length() - 1
+ else:
+ p0 = copy(q)
+ p1 = self._dbl(q)
+ top = scalar.bit_length() - 2
+ prev_bit = 0
+ for i in range(top, -1, -1):
+ k = (scalar & (1 << i)) >> i
+ swap = prev_bit ^ k
+ prev_bit = k
+ p0, p1 = (p1, p0) if swap else (p0, p1)
+ p0, p1 = self._ladd(q, p0, p1)
+ p0, p1 = (p1, p0) if prev_bit else (p0, p1)
+ if "scl" in self.formulas:
+ p0 = self._scl(p0)
+ return action.exit(p0)
+
+
+@public
class SimpleLadderMultiplier(ScalarMultiplier):
"""
Montgomery ladder multiplier, using addition and doubling formulas.
@@ -200,7 +276,7 @@ class DifferentialLadderMultiplier(ScalarMultiplier):
if scalar == 0:
return action.exit(copy(self._params.curve.neutral))
if self.complete:
- top = self._params.order.bit_length() - 1
+ top = self._params.full_order.bit_length() - 1
else:
top = scalar.bit_length() - 1
q = self._point
diff --git a/pyecsca/ec/params.py b/pyecsca/ec/params.py
index 2e437a2..627d457 100644
--- a/pyecsca/ec/params.py
+++ b/pyecsca/ec/params.py
@@ -3,6 +3,7 @@ Provides functions for obtaining domain parameters from the `std-curves <https:/
It also provides a domain parameter class and a class for a whole category of domain parameters.
"""
+
import json
import csv
from sympy import Poly, FF, symbols
@@ -115,6 +116,10 @@ class DomainParameters:
curve, generator, self.order, self.cofactor, self.name, self.category
)
+ @property
+ def full_order(self) -> int:
+ return self.order * self.cofactor
+
def __get_name(self):
if self.name and self.category:
return f"{self.category}/{self.name}"
@@ -478,7 +483,7 @@ def get_params(
if curve["name"] == name:
break
else:
- raise ValueError(f"Curve {name} not found in category {category}.")
+ raise ValueError(f"Curve {name} not found in category {category}.") # noqa
return _create_params(curve, coords, infty)
diff --git a/pyecsca/sca/target/leakage.py b/pyecsca/sca/target/leakage.py
index 74838b8..6487c56 100644
--- a/pyecsca/sca/target/leakage.py
+++ b/pyecsca/sca/target/leakage.py
@@ -9,7 +9,7 @@ from pyecsca.ec.params import DomainParameters
from pyecsca.ec.point import Point
from pyecsca.ec.mult import ScalarMultiplier
from pyecsca.ec.key_generation import KeyGeneration
-from pyecsca.ec.key_agreement import KeyAgreement
+from pyecsca.ec.key_agreement import ECDH
from pyecsca.ec.signature import Signature, SignatureResult
from pyecsca.ec.formula import FormulaAction
from pyecsca.ec.context import DefaultContext, local
@@ -121,7 +121,7 @@ class LeakageTarget(Target):
if self.privkey is None:
raise ValueError("Missing privkey")
with local(DefaultContext()) as ctx:
- ecdh = KeyAgreement(
+ ecdh = ECDH(
self.mult, self.params, other_pubkey, self.privkey, hash_algo
)
shared_secret = ecdh.perform()
diff --git a/test/conftest.py b/test/conftest.py
index 5c3f855..e12beb3 100644
--- a/test/conftest.py
+++ b/test/conftest.py
@@ -10,7 +10,12 @@ def secp128r1() -> DomainParameters:
@pytest.fixture(scope="session")
def curve25519() -> DomainParameters:
- return get_params("other", "Curve25519", "xz")
+ return get_params("other", "Curve25519", "xz", infty=False)
+
+
+@pytest.fixture(scope="session")
+def curve448() -> DomainParameters:
+ return get_params("other", "Curve448", "xz", infty=False)
@pytest.fixture(scope="session")
diff --git a/test/ec/test_key_agreement.py b/test/ec/test_key_agreement.py
index 60e48f8..4afb2de 100644
--- a/test/ec/test_key_agreement.py
+++ b/test/ec/test_key_agreement.py
@@ -12,9 +12,16 @@ from pyecsca.ec.key_agreement import (
ECDH_SHA256,
ECDH_SHA384,
ECDH_SHA512,
+ X25519,
+ X448,
)
from pyecsca.ec.mod import Mod, mod
-from pyecsca.ec.mult import LTRMultiplier
+from pyecsca.ec.mult import (
+ LTRMultiplier,
+ LadderMultiplier,
+ SwapLadderMultiplier,
+ DifferentialLadderMultiplier,
+)
import test.data.ec
from pyecsca.ec.params import get_params
from pyecsca.ec.point import Point
@@ -43,7 +50,9 @@ def keypair_b(secp128r1, mult):
return priv_b, pub_b
-@pytest.mark.parametrize("algo", [ECDH_NONE, ECDH_SHA1, ECDH_SHA224, ECDH_SHA256, ECDH_SHA384, ECDH_SHA512])
+@pytest.mark.parametrize(
+ "algo", [ECDH_NONE, ECDH_SHA1, ECDH_SHA224, ECDH_SHA256, ECDH_SHA384, ECDH_SHA512]
+)
def test_ka(algo, mult, secp128r1, keypair_a, keypair_b):
result_ab = algo(mult, secp128r1, keypair_a[1], keypair_b[0]).perform()
result_ba = algo(mult, secp128r1, keypair_b[1], keypair_a[0]).perform()
@@ -59,14 +68,18 @@ def test_ka_secg():
dbl = secp160r1.curve.coordinate_model.formulas["dbl-2015-rcb"]
mult = LTRMultiplier(add, dbl)
privA = mod(int(secg_data["keyA"]["priv"], 16), secp160r1.order)
- pubA_affine = Point(affine_model,
- x=mod(int(secg_data["keyA"]["pub"]["x"], 16), secp160r1.curve.prime),
- y=mod(int(secg_data["keyA"]["pub"]["y"], 16), secp160r1.curve.prime))
+ pubA_affine = Point(
+ affine_model,
+ x=mod(int(secg_data["keyA"]["pub"]["x"], 16), secp160r1.curve.prime),
+ y=mod(int(secg_data["keyA"]["pub"]["y"], 16), secp160r1.curve.prime),
+ )
pubA = pubA_affine.to_model(secp160r1.curve.coordinate_model, secp160r1.curve)
privB = mod(int(secg_data["keyB"]["priv"], 16), secp160r1.order)
- pubB_affine = Point(affine_model,
- x=mod(int(secg_data["keyB"]["pub"]["x"], 16), secp160r1.curve.prime),
- y=mod(int(secg_data["keyB"]["pub"]["y"], 16), secp160r1.curve.prime))
+ pubB_affine = Point(
+ affine_model,
+ x=mod(int(secg_data["keyB"]["pub"]["x"], 16), secp160r1.curve.prime),
+ y=mod(int(secg_data["keyB"]["pub"]["y"], 16), secp160r1.curve.prime),
+ )
pubB = pubB_affine.to_model(secp160r1.curve.coordinate_model, secp160r1.curve)
algoAB = ECDH_SHA1(copy(mult), secp160r1, pubA, privB)
@@ -83,3 +96,121 @@ def test_ka_secg():
n = (p.bit_length() + 7) // 8
result = x.to_bytes(n, byteorder="big")
assert result == bytes.fromhex(secg_data["raw"])
+
+
+@pytest.mark.parametrize(
+ "mult_args",
+ [
+ (LadderMultiplier, "ladd-1987-m", "dbl-1987-m", "scale"),
+ (SwapLadderMultiplier, "ladd-1987-m", "dbl-1987-m", "scale"),
+ (DifferentialLadderMultiplier, "dadd-1987-m", "dbl-1987-m", "scale"),
+ ],
+)
+@pytest.mark.parametrize("complete", [True, False])
+@pytest.mark.parametrize("short_circuit", [True, False])
+@pytest.mark.parametrize(
+ "scalar_hex,coord_hex,result_hex",
+ [
+ (
+ "A546E36BF0527C9D3B16154B82465EDD62144C0AC1FC5A18506A2244BA449AC4",
+ "E6DB6867583030DB3594C1A424B15F7C726624EC26B3353B10A903A6D0AB1C4C",
+ "C3DA55379DE9C6908E94EA4DF28D084F32ECCF03491C71F754B4075577A28552",
+ ),
+ (
+ "4b66e9d4d1b4673c5ad22691957d6af5c11b6421e0ea01d42ca4169e7918ba0d",
+ "e5210f12786811d3f4b7959d0538ae2c31dbe7106fc03c3efc4cd549c715a493",
+ "95cbde9476e8907d7aade45cb4b873f88b595a68799fa152e6f8f7647aac7957",
+ ),
+ (
+ "77076d0a7318a57d3c16c17251b26645df4c2f87ebc0992ab177fba51db92c2a",
+ "de9edb7d7b7dc1b4d35b61c2ece435373f8343c85b78674dadfc7e146f882b4f",
+ "4a5d9d5ba4ce2de1728e3bf480350f25e07e21c947d19e3376f09b3c1e161742",
+ ),
+ (
+ "5dab087e624a8a4b79e17f8b83800ee66f3bb1292618b6fd1c2f8b27ff88e0eb",
+ "8520f0098930a754748b7ddcb43ef75a0dbf3a0d26381af4eba4a98eaa9b4e6a",
+ "4a5d9d5ba4ce2de1728e3bf480350f25e07e21c947d19e3376f09b3c1e161742",
+ ),
+ ],
+ ids=["RFC7748tv1", "RFC7748tv2", "RFC7748dh1", "RFC7748dh2"],
+)
+def test_x25519(
+ curve25519, mult_args, complete, short_circuit, scalar_hex, coord_hex, result_hex
+):
+ mult_class = mult_args[0]
+ mult_formulas = list(
+ map(
+ lambda name: curve25519.curve.coordinate_model.formulas[name], mult_args[1:]
+ )
+ )
+ multiplier = mult_class(
+ *mult_formulas, complete=complete, short_circuit=short_circuit
+ )
+
+ scalar = int.from_bytes(bytes.fromhex(scalar_hex), "little")
+ coord = int.from_bytes(bytes.fromhex(coord_hex), "little")
+ result = bytes.fromhex(result_hex)
+ p = curve25519.curve.prime
+ coord &= (1 << 255) - 1
+ point = Point(curve25519.curve.coordinate_model, X=mod(coord, p), Z=mod(1, p))
+ xdh = X25519(multiplier, point, scalar)
+ res = xdh.perform()
+ assert res == result
+
+
+@pytest.mark.parametrize(
+ "mult_args",
+ [
+ (LadderMultiplier, "ladd-1987-m", "dbl-1987-m", "scale"),
+ (SwapLadderMultiplier, "ladd-1987-m", "dbl-1987-m", "scale"),
+ (DifferentialLadderMultiplier, "dadd-1987-m", "dbl-1987-m", "scale"),
+ ],
+)
+@pytest.mark.parametrize("complete", [True, False])
+@pytest.mark.parametrize("short_circuit", [True, False])
+@pytest.mark.parametrize(
+ "scalar_hex,coord_hex,result_hex",
+ [
+ (
+ "3d262fddf9ec8e88495266fea19a34d28882acef045104d0d1aae121700a779c984c24f8cdd78fbff44943eba368f54b29259a4f1c600ad3",
+ "06fce640fa3487bfda5f6cf2d5263f8aad88334cbd07437f020f08f9814dc031ddbdc38c19c6da2583fa5429db94ada18aa7a7fb4ef8a086",
+ "ce3e4ff95a60dc6697da1db1d85e6afbdf79b50a2412d7546d5f239fe14fbaadeb445fc66a01b0779d98223961111e21766282f73dd96b6f",
+ ),
+ (
+ "203d494428b8399352665ddca42f9de8fef600908e0d461cb021f8c538345dd77c3e4806e25f46d3315c44e0a5b4371282dd2c8d5be3095f",
+ "0fbcc2f993cd56d3305b0b7d9e55d4c1a8fb5dbb52f8e9a1e9b6201b165d015894e56c4d3570bee52fe205e28a78b91cdfbde71ce8d157db",
+ "884a02576239ff7a2f2f63b2db6a9ff37047ac13568e1e30fe63c4a7ad1b3ee3a5700df34321d62077e63633c575c1c954514e99da7c179d",
+ ),
+ (
+ "9a8f4925d1519f5775cf46b04b5800d4ee9ee8bae8bc5565d498c28dd9c9baf574a9419744897391006382a6f127ab1d9ac2d8c0a598726b",
+ "3eb7a829b0cd20f5bcfc0b599b6feccf6da4627107bdb0d4f345b43027d8b972fc3e34fb4232a13ca706dcb57aec3dae07bdc1c67bf33609",
+ "07fff4181ac6cc95ec1c16a94a0f74d12da232ce40a77552281d282bb60c0b56fd2464c335543936521c24403085d59a449a5037514a879d",
+ ),
+ (
+ "1c306a7ac2a0e2e0990b294470cba339e6453772b075811d8fad0d1d6927c120bb5ee8972b0d3e21374c9c921b09d1b0366f10b65173992d",
+ "9b08f7cc31b7e3e67d22d5aea121074a273bd2b83de09c63faa73d2c22c5d9bbc836647241d953d40c5b12da88120d53177f80e532c41fa0",
+ "07fff4181ac6cc95ec1c16a94a0f74d12da232ce40a77552281d282bb60c0b56fd2464c335543936521c24403085d59a449a5037514a879d",
+ ),
+ ],
+ ids=["RFC7748tv1", "RFC7748tv2", "RFC7748dh1", "RFC7748dh2"],
+)
+def test_x448(
+ curve448, mult_args, complete, short_circuit, scalar_hex, coord_hex, result_hex
+):
+ mult_class = mult_args[0]
+ mult_formulas = list(
+ map(lambda name: curve448.curve.coordinate_model.formulas[name], mult_args[1:])
+ )
+ multiplier = mult_class(
+ *mult_formulas, complete=complete, short_circuit=short_circuit
+ )
+
+ scalar = int.from_bytes(bytes.fromhex(scalar_hex), "little")
+ coord = int.from_bytes(bytes.fromhex(coord_hex), "little")
+ result = bytes.fromhex(result_hex)
+ p = curve448.curve.prime
+
+ point = Point(curve448.curve.coordinate_model, X=mod(coord, p), Z=mod(1, p))
+ xdh = X448(multiplier, point, scalar)
+ res = xdh.perform()
+ assert res == result
diff --git a/test/ec/test_mult.py b/test/ec/test_mult.py
index e477b07..e65dfdd 100644
--- a/test/ec/test_mult.py
+++ b/test/ec/test_mult.py
@@ -3,6 +3,7 @@ from typing import Sequence, List
import pytest
+from pyecsca.ec.context import local, DefaultContext
from pyecsca.ec.mod import Mod, mod
from pyecsca.ec.mult import (
DoubleAndAddMultiplier,
@@ -22,9 +23,11 @@ from pyecsca.ec.mult import (
BGMWMultiplier,
CombMultiplier,
WindowBoothMultiplier,
+ SwapLadderMultiplier,
)
from pyecsca.ec.mult.fixed import FullPrecompMultiplier
from pyecsca.ec.point import InfinityPoint, Point
+from pyecsca.sca import MultipleContext
def get_formulas(coords, *names):
@@ -54,7 +57,7 @@ def do_basic_test(mult_class, params, base, add, dbl, scale, neg=None, **kwargs)
except NotImplementedError:
pass
mult.init(params, base)
- assert InfinityPoint(params.curve.coordinate_model) == mult.multiply(0)
+ assert params.curve.neutral == mult.multiply(0)
return res
@@ -186,7 +189,7 @@ def test_ladder(curve25519):
29893438142586401087946310744922998080771935139441267052026283852717044358472,
48084050389777770101701157326923977117307187144965043058462938058489685090437,
40694087602335028385342029955981451169449898924211721351135404099078471497195,
- )
+ ),
],
)
def test_ladder_full(curve25519, scalar, x, res):
@@ -194,24 +197,19 @@ def test_ladder_full(curve25519, scalar, x, res):
point = Point(curve25519.curve.coordinate_model, X=mod(x, p), Z=mod(1, p))
result = Point(curve25519.curve.coordinate_model, X=mod(res, p), Z=mod(1, p))
- mult = LadderMultiplier(
- curve25519.curve.coordinate_model.formulas["ladd-1987-m"],
- curve25519.curve.coordinate_model.formulas["dbl-1987-m"],
- # complete=False
- )
- fixed = int(mod(scalar, curve25519.order))
-
- mult.init(curve25519, point)
- computed = mult.multiply(fixed)
+ for complete in (True, False):
+ mult = LadderMultiplier(
+ curve25519.curve.coordinate_model.formulas["ladd-1987-m"],
+ curve25519.curve.coordinate_model.formulas["dbl-1987-m"],
+ complete=complete,
+ )
- point_aff = list(curve25519.curve.affine_lift_x(mod(x, p)))[0]
- result_aff = list(curve25519.curve.affine_lift_x(mod(res, p)))[0]
- computed_aff = curve25519.curve.affine_multiply(point_aff, scalar)
+ mult.init(curve25519, point)
+ computed = mult.multiply(scalar)
- scale = curve25519.curve.coordinate_model.formulas["scale"]
- converted = scale(p, computed, **curve25519.curve.parameters)[0]
- assert computed_aff.x == result_aff.x
- assert converted.X == result.X
+ scale = curve25519.curve.coordinate_model.formulas["scale"]
+ converted = scale(p, computed, **curve25519.curve.parameters)[0]
+ assert converted.X == result.X
@pytest.mark.parametrize(
@@ -229,35 +227,71 @@ def test_simple_ladder(secp128r1, add, dbl, scale):
@pytest.mark.parametrize(
- "num,complete",
+ "num",
+ [
+ 15,
+ 2355498743,
+ 325385790209017329644351321912443757746,
+ 0x1000000000000000000000000000000014DEF9DEA2F79CD65812631A5CF5D3ED - 1,
+ ],
+)
+@pytest.mark.parametrize("complete", [True, False])
+@pytest.mark.parametrize("short_circuit", [True, False])
+def test_ladder_swap(curve25519, num, complete, short_circuit):
+ ladder = LadderMultiplier(
+ curve25519.curve.coordinate_model.formulas["ladd-1987-m"],
+ curve25519.curve.coordinate_model.formulas["dbl-1987-m"],
+ curve25519.curve.coordinate_model.formulas["scale"],
+ complete=complete,
+ short_circuit=short_circuit,
+ )
+ swap = SwapLadderMultiplier(
+ curve25519.curve.coordinate_model.formulas["ladd-1987-m"],
+ curve25519.curve.coordinate_model.formulas["dbl-1987-m"],
+ curve25519.curve.coordinate_model.formulas["scale"],
+ complete=complete,
+ short_circuit=short_circuit,
+ )
+ ladder.init(curve25519, curve25519.generator)
+ res_ladder = ladder.multiply(num)
+ swap.init(curve25519, curve25519.generator)
+ res_swap = swap.multiply(num)
+ assert res_ladder == res_swap
+ assert curve25519.curve.neutral == swap.multiply(0)
+
+
+@pytest.mark.parametrize(
+ "num",
[
- (15, True),
- (15, False),
- (2355498743, True),
- (2355498743, False),
- (325385790209017329644351321912443757746, True),
- (325385790209017329644351321912443757746, False),
+ 15,
+ 2355498743,
+ 325385790209017329644351321912443757746,
+ 0x1000000000000000000000000000000014DEF9DEA2F79CD65812631A5CF5D3ED - 1,
],
)
-def test_ladder_differential(curve25519, num, complete):
+@pytest.mark.parametrize("complete", [True, False])
+@pytest.mark.parametrize("short_circuit", [True, False])
+def test_ladder_differential(curve25519, num, complete, short_circuit):
ladder = LadderMultiplier(
curve25519.curve.coordinate_model.formulas["ladd-1987-m"],
curve25519.curve.coordinate_model.formulas["dbl-1987-m"],
curve25519.curve.coordinate_model.formulas["scale"],
complete=complete,
+ short_circuit=short_circuit,
)
differential = DifferentialLadderMultiplier(
curve25519.curve.coordinate_model.formulas["dadd-1987-m"],
curve25519.curve.coordinate_model.formulas["dbl-1987-m"],
curve25519.curve.coordinate_model.formulas["scale"],
complete=complete,
+ short_circuit=short_circuit,
)
ladder.init(curve25519, curve25519.generator)
res_ladder = ladder.multiply(num)
differential.init(curve25519, curve25519.generator)
res_differential = differential.multiply(num)
assert res_ladder == res_differential
- assert InfinityPoint(curve25519.curve.coordinate_model) == differential.multiply(0)
+ assert curve25519.curve.neutral == differential.multiply(0)
@pytest.mark.parametrize(
diff --git a/test/sca/test_tree.py b/test/sca/test_tree.py
index fe918ec..f3e87cd 100644
--- a/test/sca/test_tree.py
+++ b/test/sca/test_tree.py
@@ -118,7 +118,6 @@ def test_build_tree_dedup():
tree = Tree.build(cfgs, original)
dedup = Tree.build(cfgs, dmap)
dedup_other = Tree.build(cfgs, deduplicated)
- print(tree.describe())
assert tree.describe() == dedup.describe()
assert tree.describe() == dedup_other.describe()