aboutsummaryrefslogtreecommitdiff
path: root/pyecsca/ec/key_agreement.py
diff options
context:
space:
mode:
Diffstat (limited to 'pyecsca/ec/key_agreement.py')
-rw-r--r--pyecsca/ec/key_agreement.py150
1 files changed, 131 insertions, 19 deletions
diff --git a/pyecsca/ec/key_agreement.py b/pyecsca/ec/key_agreement.py
index fe07b1f..c15c0b6 100644
--- a/pyecsca/ec/key_agreement.py
+++ b/pyecsca/ec/key_agreement.py
@@ -1,13 +1,16 @@
-"""Provides an implementation of ECDH (Elliptic Curve Diffie-Hellman)."""
+"""Provides an implementation of ECDH (Elliptic Curve Diffie-Hellman) and XDH (X25519, X448)."""
+
import hashlib
+from abc import abstractmethod, ABC
from typing import Optional, Any
from public import public
from pyecsca.ec.context import ResultAction
from pyecsca.ec.mod import Mod
+from pyecsca.ec.model import MontgomeryModel
from pyecsca.ec.mult import ScalarMultiplier
-from pyecsca.ec.params import DomainParameters
+from pyecsca.ec.params import DomainParameters, get_params
from pyecsca.ec.point import Point
@@ -38,7 +41,126 @@ class ECDHAction(ResultAction):
@public
-class KeyAgreement:
+class XDHAction(ResultAction):
+ """XDH key exchange."""
+
+ params: DomainParameters
+ privkey: int
+ pubkey: Point
+
+ def __init__(self, params: DomainParameters, privkey: int, pubkey: Point):
+ super().__init__()
+ self.params = params
+ self.privkey = privkey
+ self.pubkey = pubkey
+
+ def __repr__(self):
+ return (
+ f"{self.__class__.__name__}({self.params}, {self.privkey}, {self.pubkey})"
+ )
+
+
+@public
+class KeyAgreement(ABC):
+ """An abstract EC-based key agreement."""
+
+ @abstractmethod
+ def perform_raw(self) -> Point:
+ """
+ Perform the scalar-multiplication of the key agreement.
+
+ :return: The shared point.
+ """
+ ...
+
+ @abstractmethod
+ def perform(self) -> bytes:
+ """
+ Perform the key agreement operation.
+
+ :return: The shared secret.
+ """
+ ...
+
+
+@public
+class XDH(KeyAgreement):
+ def __init__(
+ self,
+ mult: ScalarMultiplier,
+ params: DomainParameters,
+ pubkey: Point,
+ privkey: int,
+ bits: int,
+ bytes: int,
+ ):
+ if "scl" not in mult.formulas:
+ raise ValueError("ScalarMultiplier needs to have the scaling formula.")
+ if not isinstance(params.curve.model, MontgomeryModel):
+ raise ValueError("Invalid curve model.")
+ self.mult = mult
+ self.params = params
+ self.pubkey = pubkey
+ self.privkey = privkey
+ self.bits = bits
+ self.bytes = bytes
+ self.mult.init(self.params, self.pubkey)
+
+ def clamp(self, scalar: int) -> int:
+ return scalar
+
+ def perform_raw(self) -> Point:
+ clamped = self.clamp(self.privkey)
+ return self.mult.multiply(clamped)
+
+ def perform(self) -> bytes:
+ with XDHAction(self.params, self.privkey, self.pubkey) as action:
+ point = self.perform_raw()
+ return action.exit(int(point.X).to_bytes(self.bytes, "little"))
+
+
+@public
+class X25519(XDH):
+ """
+ X25519 (or Curve25519) from [RFC7748]_.
+
+ .. warning::
+ You need to clear the top bit of the point coordinate (pubkey) before converting to a point.
+ """
+
+ def __init__(self, mult: ScalarMultiplier, pubkey: Point, privkey: int):
+ curve25519 = get_params(
+ "other", "Curve25519", pubkey.coordinate_model.name, infty=False
+ )
+ super().__init__(mult, curve25519, pubkey, privkey, 255, 32)
+
+ def clamp(self, scalar: int) -> int:
+ scalar &= ~7
+ scalar &= ~(128 << 8 * 31)
+ scalar |= 64 << 8 * 31
+ return scalar
+
+
+@public
+class X448(XDH):
+ """
+ X448 (or Curve448) from [RFC7748]_.
+ """
+
+ def __init__(self, mult: ScalarMultiplier, pubkey: Point, privkey: int):
+ curve448 = get_params(
+ "other", "Curve448", pubkey.coordinate_model.name, infty=False
+ )
+ super().__init__(mult, curve448, pubkey, privkey, 448, 56)
+
+ def clamp(self, scalar: int) -> int:
+ scalar &= ~3
+ scalar |= 128 << 8 * 55
+ return scalar
+
+
+@public
+class ECDH(KeyAgreement):
"""EC based key agreement primitive (ECDH)."""
mult: ScalarMultiplier
@@ -63,20 +185,10 @@ class KeyAgreement:
self.mult.init(self.params, self.pubkey)
def perform_raw(self) -> Point:
- """
- Perform the scalar-multiplication of the key agreement.
-
- :return: The shared point.
- """
point = self.mult.multiply(int(self.privkey))
return point.to_affine()
def perform(self) -> bytes:
- """
- Perform the key agreement operation.
-
- :return: The shared secret.
- """
with ECDHAction(
self.params, self.hash_algo, self.privkey, self.pubkey
) as action:
@@ -91,7 +203,7 @@ class KeyAgreement:
@public
-class ECDH_NONE(KeyAgreement):
+class ECDH_NONE(ECDH):
"""Raw x-coordinate ECDH."""
def __init__(
@@ -105,7 +217,7 @@ class ECDH_NONE(KeyAgreement):
@public
-class ECDH_SHA1(KeyAgreement):
+class ECDH_SHA1(ECDH):
"""ECDH with SHA1 of x-coordinate."""
def __init__(
@@ -119,7 +231,7 @@ class ECDH_SHA1(KeyAgreement):
@public
-class ECDH_SHA224(KeyAgreement):
+class ECDH_SHA224(ECDH):
"""ECDH with SHA224 of x-coordinate."""
def __init__(
@@ -133,7 +245,7 @@ class ECDH_SHA224(KeyAgreement):
@public
-class ECDH_SHA256(KeyAgreement):
+class ECDH_SHA256(ECDH):
"""ECDH with SHA256 of x-coordinate."""
def __init__(
@@ -147,7 +259,7 @@ class ECDH_SHA256(KeyAgreement):
@public
-class ECDH_SHA384(KeyAgreement):
+class ECDH_SHA384(ECDH):
"""ECDH with SHA384 of x-coordinate."""
def __init__(
@@ -161,7 +273,7 @@ class ECDH_SHA384(KeyAgreement):
@public
-class ECDH_SHA512(KeyAgreement):
+class ECDH_SHA512(ECDH):
"""ECDH with SHA512 of x-coordinate."""
def __init__(