aboutsummaryrefslogtreecommitdiff
path: root/test/sca/test_rpa_context.py
blob: c50b42036355e8849ff3d623f317899ff7d5ca3e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
from typing import cast

import pytest

from pyecsca.ec.context import local
from pyecsca.ec.formula import (
    LadderFormula,
    DifferentialAdditionFormula,
    DoublingFormula,
    ScalingFormula,
)
from pyecsca.ec.mult import (
    LTRMultiplier,
    BinaryNAFMultiplier,
    WindowNAFMultiplier,
    LadderMultiplier,
    DifferentialLadderMultiplier,
)
from pyecsca.sca.re.rpa import MultipleContext


@pytest.fixture()
def add(secp128r1):
    """Provide the ``add-1998-cmo`` formula of secp128r1's coordinate model."""
    formulas = secp128r1.curve.coordinate_model.formulas
    return formulas["add-1998-cmo"]


@pytest.fixture()
def dbl(secp128r1):
    """Provide the ``dbl-1998-cmo`` formula of secp128r1's coordinate model."""
    formulas = secp128r1.curve.coordinate_model.formulas
    return formulas["dbl-1998-cmo"]


@pytest.fixture()
def neg(secp128r1):
    """Provide the ``neg`` formula of secp128r1's coordinate model."""
    formulas = secp128r1.curve.coordinate_model.formulas
    return formulas["neg"]


@pytest.fixture()
def scale(secp128r1):
    """Provide the ``z`` formula of secp128r1's coordinate model (used here for scaling)."""
    formulas = secp128r1.curve.coordinate_model.formulas
    return formulas["z"]


@pytest.mark.parametrize(
    "name,scalar",
    # Each case is (decimal string of the scalar, scalar); the string only feeds the test id.
    [
        (str(k), k)
        for k in (
            5,
            10,
            2355498743,
            325385790209017329644351321912443757746,
            13613624287328732,
        )
    ],
)
def test_basic(secp128r1, add, dbl, scale, name, scalar):
    """An LTR multiplication by ``scalar`` leaves ``scalar`` as the last value in ``ctx.points``.

    ``name`` is just the string form of ``scalar`` (used for the test id) and is not
    read by the body.
    """
    multiplier = LTRMultiplier(
        add, dbl, scale, always=False, complete=False, short_circuit=True
    )
    with local(MultipleContext()) as context:
        multiplier.init(secp128r1, secp128r1.generator)
        multiplier.multiply(scalar)
    recorded = list(context.points.values())
    assert recorded[-1] == scalar


def test_precomp(secp128r1, add, dbl, neg, scale):
    """Running only ``init`` records the expected multiples for BNAF and width-3 wNAF."""

    def multiples_after_init(multiplier):
        # Only ``init`` runs inside the context, so everything recorded stems from it.
        with local(MultipleContext()) as context:
            multiplier.init(secp128r1, secp128r1.generator)
        return list(context.points.values())

    bnaf_multiples = multiples_after_init(BinaryNAFMultiplier(add, dbl, neg, scale))
    assert bnaf_multiples == [1, 0, -1]

    wnaf_multiples = multiples_after_init(WindowNAFMultiplier(add, dbl, neg, 3, scale))
    assert wnaf_multiples == [1, 0, 2, 3, 5]


def test_window(secp128r1, add, dbl, neg):
    """With ``precompute_negation=True``, ``ctx.precomp`` is truthy after multiplying by 5."""
    wnaf = WindowNAFMultiplier(add, dbl, neg, 3, precompute_negation=True)
    with local(MultipleContext()) as context:
        wnaf.init(secp128r1, secp128r1.generator)
        wnaf.multiply(5)
    assert context.precomp


def test_ladder(curve25519):
    """Both Montgomery ladder multipliers record the scalar as the last and third-to-last multiple."""
    scalar = 1339278426732672313
    base = curve25519.generator
    formulas = curve25519.curve.coordinate_model.formulas
    ladd = cast(LadderFormula, formulas["ladd-1987-m"])
    dadd = cast(DifferentialAdditionFormula, formulas["dadd-1987-m"])
    dbl = cast(DoublingFormula, formulas["dbl-1987-m"])
    scale = cast(ScalingFormula, formulas["scale"])

    # Each multiplier is constructed just before it is exercised, and a failing
    # assertion for the first one stops the test before the second is built.
    for multiplier_cls, step_formula in (
        (LadderMultiplier, ladd),
        (DifferentialLadderMultiplier, dadd),
    ):
        multiplier = multiplier_cls(step_formula, dbl, scale)
        with local(MultipleContext()) as context:
            multiplier.init(curve25519, base)
            multiplier.multiply(scalar)
        recorded = list(context.points.values())
        assert recorded[-1] == scalar
        assert recorded[-3] == scalar


def test_keep_base(secp128r1, add, dbl):
    """With ``keep_base=True``, re-initialising on a derived point still records 50.

    The generator is multiplied by 5, the result becomes the new base and is
    multiplied by 10. The context is expected to contain the multiple 50 (5 * 10,
    i.e. relative to the original generator).
    """
    multiplier = LTRMultiplier(
        add, dbl, always=False, complete=False, short_circuit=True
    )
    with local(MultipleContext(keep_base=True)) as context:
        multiplier.init(secp128r1, secp128r1.generator)
        five_g = multiplier.multiply(5)
        multiplier.init(secp128r1, five_g)
        multiplier.multiply(10)
    assert 50 in context.points.values()