aboutsummaryrefslogtreecommitdiff
path: root/test/ec/test_key_agreement.py
blob: 240c1740948e7fda97c16c719124f629198f56ee (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
from unittest import TestCase

from parameterized import parameterized

from pyecsca.ec.params import get_params
from pyecsca.ec.key_agreement import (
    ECDH_NONE,
    ECDH_SHA1,
    ECDH_SHA224,
    ECDH_SHA256,
    ECDH_SHA384,
    ECDH_SHA512,
)
from pyecsca.ec.mod import Mod
from pyecsca.ec.mult import LTRMultiplier


class KeyAgreementTests(TestCase):
    """Agreement (round-trip) tests for the ECDH variants on secp128r1."""

    def setUp(self):
        """Prepare curve parameters, a scalar multiplier and two key pairs.

        The multiplier is an ``LTRMultiplier`` built from the
        ``add-2007-bl`` / ``dbl-2007-bl`` formulas of the projective
        coordinate model. It is initialised on the curve generator, and then
        used to turn the two fixed private scalars into public points. The
        same multiplier object is later passed to the key-agreement classes.
        """
        curve_params = get_params("secg", "secp128r1", "projective")
        formulas = curve_params.curve.coordinate_model.formulas
        self.secp128r1 = curve_params
        self.add = formulas["add-2007-bl"]
        self.dbl = formulas["dbl-2007-bl"]
        self.mult = LTRMultiplier(self.add, self.dbl)
        # Fix the base point to the generator before deriving public keys.
        self.mult.init(curve_params, curve_params.generator)
        order = curve_params.order
        self.priv_a = Mod(0xDEADBEEF, order)
        self.priv_b = Mod(0xCAFEBABE, order)
        self.pub_a = self.mult.multiply(int(self.priv_a))
        self.pub_b = self.mult.multiply(int(self.priv_b))

    @parameterized.expand(
        [
            ("NONE", ECDH_NONE),
            ("SHA1", ECDH_SHA1),
            ("SHA224", ECDH_SHA224),
            ("SHA256", ECDH_SHA256),
            ("SHA384", ECDH_SHA384),
            ("SHA512", ECDH_SHA512),
        ]
    )
    def test_all(self, name, algo):
        """Check that both parties of ``algo`` derive the same output.

        ``name`` is not used in the body; it only labels the generated case.
        """
        params, mult = self.secp128r1, self.mult
        # Party B: A's public key combined with B's private key.
        secret_seen_by_b = algo(mult, params, self.pub_a, self.priv_b).perform()
        # Party A: B's public key combined with A's private key.
        secret_seen_by_a = algo(mult, params, self.pub_b, self.priv_a).perform()
        self.assertEqual(secret_seen_by_b, secret_seen_by_a)

    # TODO: Add known-answer-test (KAT) based tests here.