aboutsummaryrefslogtreecommitdiff
path: root/test/ec/test_key_agreement.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/ec/test_key_agreement.py')
-rw-r--r--test/ec/test_key_agreement.py60
1 files changed, 29 insertions, 31 deletions
diff --git a/test/ec/test_key_agreement.py b/test/ec/test_key_agreement.py
index 240c174..14bd138 100644
--- a/test/ec/test_key_agreement.py
+++ b/test/ec/test_key_agreement.py
@@ -1,8 +1,4 @@
-from unittest import TestCase
-
-from parameterized import parameterized
-
-from pyecsca.ec.params import get_params
+import pytest
from pyecsca.ec.key_agreement import (
ECDH_NONE,
ECDH_SHA1,
@@ -15,31 +11,33 @@ from pyecsca.ec.mod import Mod
from pyecsca.ec.mult import LTRMultiplier
-class KeyAgreementTests(TestCase):
- def setUp(self):
- self.secp128r1 = get_params("secg", "secp128r1", "projective")
- self.add = self.secp128r1.curve.coordinate_model.formulas["add-2007-bl"]
- self.dbl = self.secp128r1.curve.coordinate_model.formulas["dbl-2007-bl"]
- self.mult = LTRMultiplier(self.add, self.dbl)
- self.priv_a = Mod(0xDEADBEEF, self.secp128r1.order)
- self.mult.init(self.secp128r1, self.secp128r1.generator)
- self.pub_a = self.mult.multiply(int(self.priv_a))
- self.priv_b = Mod(0xCAFEBABE, self.secp128r1.order)
- self.pub_b = self.mult.multiply(int(self.priv_b))
+@pytest.fixture()
+def mult(secp128r1):
+ add = secp128r1.curve.coordinate_model.formulas["add-2007-bl"]
+ dbl = secp128r1.curve.coordinate_model.formulas["dbl-2007-bl"]
+ return LTRMultiplier(add, dbl)
+
+
+@pytest.fixture()
+def keypair_a(secp128r1, mult):
+ priv_a = Mod(0xDEADBEEF, secp128r1.order)
+ mult.init(secp128r1, secp128r1.generator)
+ pub_a = mult.multiply(int(priv_a))
+ return priv_a, pub_a
+
+
+@pytest.fixture()
+def keypair_b(secp128r1, mult):
+ priv_b = Mod(0xCAFEBABE, secp128r1.order)
+ mult.init(secp128r1, secp128r1.generator)
+ pub_b = mult.multiply(int(priv_b))
+ return priv_b, pub_b
+
- @parameterized.expand(
- [
- ("NONE", ECDH_NONE),
- ("SHA1", ECDH_SHA1),
- ("SHA224", ECDH_SHA224),
- ("SHA256", ECDH_SHA256),
- ("SHA384", ECDH_SHA384),
- ("SHA512", ECDH_SHA512),
- ]
- )
- def test_all(self, name, algo):
- result_ab = algo(self.mult, self.secp128r1, self.pub_a, self.priv_b).perform()
- result_ba = algo(self.mult, self.secp128r1, self.pub_b, self.priv_a).perform()
- self.assertEqual(result_ab, result_ba)
+@pytest.mark.parametrize("algo", [ECDH_NONE, ECDH_SHA1, ECDH_SHA224, ECDH_SHA256, ECDH_SHA384, ECDH_SHA512])
+def test_ka(algo, mult, secp128r1, keypair_a, keypair_b):
+ result_ab = algo(mult, secp128r1, keypair_a[1], keypair_b[0]).perform()
+ result_ba = algo(mult, secp128r1, keypair_b[1], keypair_a[0]).perform()
+ assert result_ab == result_ba
- # TODO: Add KAT-based tests here.
+# TODO: Add KAT-based tests here.