diff options
Diffstat (limited to 'test/ec/test_key_agreement.py')
| -rw-r--r-- | test/ec/test_key_agreement.py | 60 |
1 files changed, 29 insertions, 31 deletions
diff --git a/test/ec/test_key_agreement.py b/test/ec/test_key_agreement.py index 240c174..14bd138 100644 --- a/test/ec/test_key_agreement.py +++ b/test/ec/test_key_agreement.py @@ -1,8 +1,4 @@ -from unittest import TestCase - -from parameterized import parameterized - -from pyecsca.ec.params import get_params +import pytest from pyecsca.ec.key_agreement import ( ECDH_NONE, ECDH_SHA1, @@ -15,31 +11,33 @@ from pyecsca.ec.mod import Mod from pyecsca.ec.mult import LTRMultiplier -class KeyAgreementTests(TestCase): - def setUp(self): - self.secp128r1 = get_params("secg", "secp128r1", "projective") - self.add = self.secp128r1.curve.coordinate_model.formulas["add-2007-bl"] - self.dbl = self.secp128r1.curve.coordinate_model.formulas["dbl-2007-bl"] - self.mult = LTRMultiplier(self.add, self.dbl) - self.priv_a = Mod(0xDEADBEEF, self.secp128r1.order) - self.mult.init(self.secp128r1, self.secp128r1.generator) - self.pub_a = self.mult.multiply(int(self.priv_a)) - self.priv_b = Mod(0xCAFEBABE, self.secp128r1.order) - self.pub_b = self.mult.multiply(int(self.priv_b)) +@pytest.fixture() +def mult(secp128r1): + add = secp128r1.curve.coordinate_model.formulas["add-2007-bl"] + dbl = secp128r1.curve.coordinate_model.formulas["dbl-2007-bl"] + return LTRMultiplier(add, dbl) + + +@pytest.fixture() +def keypair_a(secp128r1, mult): + priv_a = Mod(0xDEADBEEF, secp128r1.order) + mult.init(secp128r1, secp128r1.generator) + pub_a = mult.multiply(int(priv_a)) + return priv_a, pub_a + + +@pytest.fixture() +def keypair_b(secp128r1, mult): + priv_b = Mod(0xCAFEBABE, secp128r1.order) + mult.init(secp128r1, secp128r1.generator) + pub_b = mult.multiply(int(priv_b)) + return priv_b, pub_b + - @parameterized.expand( - [ - ("NONE", ECDH_NONE), - ("SHA1", ECDH_SHA1), - ("SHA224", ECDH_SHA224), - ("SHA256", ECDH_SHA256), - ("SHA384", ECDH_SHA384), - ("SHA512", ECDH_SHA512), - ] - ) - def test_all(self, name, algo): - result_ab = algo(self.mult, self.secp128r1, self.pub_a, self.priv_b).perform() - result_ba = algo(self.mult, self.secp128r1, self.pub_b, self.priv_a).perform() - self.assertEqual(result_ab, result_ba) +@pytest.mark.parametrize("algo", [ECDH_NONE, ECDH_SHA1, ECDH_SHA224, ECDH_SHA256, ECDH_SHA384, ECDH_SHA512]) +def test_ka(algo, mult, secp128r1, keypair_a, keypair_b): + result_ab = algo(mult, secp128r1, keypair_a[1], keypair_b[0]).perform() + result_ba = algo(mult, secp128r1, keypair_b[1], keypair_a[0]).perform() + assert result_ab == result_ba - # TODO: Add KAT-based tests here. +# TODO: Add KAT-based tests here. |
