aboutsummaryrefslogtreecommitdiffhomepage
path: root/test/test_equivalence.py
blob: 5a4872914a3ca3d650424f56501157a0a7f2d79f (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
import subprocess
import json
from tempfile import NamedTemporaryFile
from typing import Generator, Any
from click.testing import CliRunner

import pytest
from importlib import resources
from os.path import join

from pyecsca.codegen.builder import build_impl
from pyecsca.ec.formula import FormulaAction, NegationFormula
from pyecsca.ec.model import CurveModel
from pyecsca.ec.coordinates import CoordinateModel
from pyecsca.sca.target.binary import BinaryTarget
from pyecsca.codegen.client import ImplTarget
from pyecsca.ec.context import DefaultContext, local, Node


class GDBTarget(ImplTarget, BinaryTarget):
    """Implementation target that runs the built binary under GDB.

    ``connect`` starts ``gdb`` in batch mode with the ``gdb_script.py``
    resource of the ``test`` package and hands it the path of a fresh
    temporary file through the ``TRACE_FILE`` environment variable.
    The open file object is kept in ``trace_file`` until ``disconnect``.
    """

    def __init__(self, model: CurveModel, coords: CoordinateModel, **kwargs):
        """Forward all arguments to the parent targets; no trace file exists yet."""
        super().__init__(model, coords, **kwargs)
        # Open handle of the current connection's trace file, None when not connected.
        self.trace_file = None

    def connect(self):
        """Create a new trace file and spawn GDB on ``self.binary``."""
        self.trace_file = NamedTemporaryFile("r+")
        # The Popen call stays inside the ``with`` block: the resource path is
        # only guaranteed to exist while the context manager is active.
        with resources.path("test", "gdb_script.py") as gdb_script:
            self.process = subprocess.Popen(
                [
                    "gdb",
                    "--q",
                    "--batch-silent",
                    "-x",
                    gdb_script,
                    "--args",
                    *self.binary,
                ],
                stdin=subprocess.PIPE,
                stdout=subprocess.PIPE,
                stderr=subprocess.DEVNULL,
                # NOTE: this replaces the whole environment of the child, it
                # only sees TRACE_FILE.
                env={
                    "TRACE_FILE": self.trace_file.name,
                },
                text=True,
                bufsize=1,
            )

    def disconnect(self):
        """Disconnect from the process and release the trace file.

        Closing the temporary file also deletes it, so the trace has to be
        read from ``trace_file`` before this method is called.
        """
        try:
            super().disconnect()
        finally:
            # Always release the handle, even if the parent disconnect fails,
            # so repeated connect/disconnect cycles do not pile up open files.
            if self.trace_file is not None:
                self.trace_file.close()
                self.trace_file = None


@pytest.fixture(scope="module")
def target(simple_multiplier, secp128r1) -> Generator[GDBTarget, Any, None]:
    """Build the HOST implementation for a multiplier and yield a target for it.

    The implementation is built with ``build_impl`` inside an isolated
    temporary directory. The yielded :class:`GDBTarget` points at the built
    ``pyecsca-codegen-HOST.elf`` and carries a matching Python multiplier
    instance in its ``mult`` attribute.
    """
    mult_class, mult_kwargs = simple_multiplier
    curve = secp128r1.curve

    # The multiplier always gets add and dbl, negation only when it asks for it.
    needs_negation = NegationFormula in mult_class.requires
    formulas = ["add-1998-cmo", "dbl-1998-cmo"] + (["neg"] if needs_negation else [])

    # Multiplier specification in the "Name(key=value,...)" form given to build_impl.
    kwarg_spec = ",".join(f"{key}={value}" for key, value in mult_kwargs.items())
    mult_spec = f"{mult_class.__name__}({kwarg_spec})"

    runner = CliRunner()
    with runner.isolated_filesystem() as tmpdir:
        cli_args = [
            "--platform",
            "HOST",
            "--ecdsa",
            "--ecdh",
            curve.model.shortname,
            curve.coordinate_model.name,
            *formulas,
            mult_spec,
            ".",
        ]
        res = runner.invoke(
            build_impl,
            cli_args,
            env={"DEBUG": "1", "CFLAGS": "-g -O0"},
        )
        assert res.exit_code == 0

        gdb_target = GDBTarget(
            curve.model,
            curve.coordinate_model,
            binary=join(tmpdir, "pyecsca-codegen-HOST.elf"),
        )
        # Instantiate the same multiplier in Python, with the same formulas.
        formula_instances = [curve.coordinate_model.formulas[name] for name in formulas]
        gdb_target.mult = mult_class(*formula_instances, **mult_kwargs)
        yield gdb_target


def parse_trace(captured: str):
    """Parse a textual call trace into a list of ``(name, args, rets)`` tuples.

    The trace is processed line by line:

    * A line without a colon starts a new record and holds the function
      name; a leading ``point_`` prefix is dropped.
    * A line with a colon has the form ``name: <json>``. The decoded JSON
      value is stored among the return values if ``name`` contains ``out``
      and among the arguments otherwise.
    * Blank lines are ignored.

    A ``set`` record discards itself as well as the record that was open when
    it appeared (see the comment in the body).

    :param captured: The whole trace text.
    :return: The parsed records, in the order they appear in the trace.
    :raises json.JSONDecodeError: If a value is not valid JSON.
    """
    current_function = None
    args = []
    rets = []
    result = []
    for line in captured.split("\n"):
        if not line.strip():
            # Blank lines (including the one after a trailing newline) carry
            # no information and must not be mistaken for a function name.
            continue
        if ":" not in line:
            func = line.strip()
            if func.startswith("point_"):
                func = func[len("point_") :]
            if func == "set":
                # The sets that happen inside another formula (like add) are a sign of short-circuiting.
                # The Python simulation does not record the short-circuits, so we ignore them here.
                current_function = None
            else:
                if current_function is not None:
                    result.append((current_function, args, rets))
                current_function = func
            args = []
            rets = []
        else:
            name, data = line.split(":", 1)
            name = name.strip()
            value = json.loads(data)
            if "out" in name:
                rets.append(value)
            else:
                args.append(value)
    # Emit the record that was still open when the trace ended, so the result
    # does not depend on the trace being terminated by a newline.
    if current_function is not None:
        result.append((current_function, args, rets))
    return result


def parse_ctx(scalarmult: Node):
    """Turn the children of a context node into ``(name, args, rets)`` tuples.

    Every child is expected to hold a formula action. Its formula shortname
    becomes the name, its input points the arguments and its output points
    the return values. Each point is represented as a dict mapping the
    coordinate name to the coordinate converted with ``int``.

    :param scalarmult: The node whose direct children are processed.
    :return: One tuple per child, in child order.
    """

    def as_int_dicts(points):
        # One plain {coordinate name: int} dict per point.
        return [
            {coord: int(value) for coord, value in point.coords.items()}
            for point in points
        ]

    entries = []
    for child in scalarmult.children:
        action: FormulaAction = child.action
        entries.append(
            (
                action.formula.shortname,
                as_int_dicts(action.input_points),
                as_int_dicts(action.output_points),
            )
        )
    return entries


def make_hashable(trace):
    """Convert a trace into a nested tuple structure that can be hashed.

    Every ``(name, args, rets)`` entry keeps its name, while each dict in
    ``args`` and ``rets`` becomes a tuple of its ``(key, value)`` pairs sorted
    by key. Sorting makes the result independent of dict insertion order, so
    two dicts that compare equal always yield the same tuple.

    :param trace: Iterable of ``(name, args, rets)`` entries where ``args``
        and ``rets`` are lists of dicts.
    :return: A tuple with one ``(name, args, rets)`` tuple per entry.
    """

    def freeze(points):
        return tuple(tuple(sorted(point.items())) for point in points)

    return tuple((name, freeze(args), freeze(rets)) for name, args, rets in trace)


def test_equivalence(target, secp128r1):
    """Check that the built implementation and the simulation do the same formula calls.

    For ten different PRNG seeds a keypair is generated on the target and the
    same scalar multiplication is performed by the Python multiplier inside a
    ``DefaultContext``. The public key has to match the simulated result and
    the sequence of calls parsed from the target's trace file has to equal
    the sequence recorded by the context.
    """
    mult = target.mult
    for i in range(10):
        target.connect()
        try:
            target.init_prng(i.to_bytes(4, "big"))
            target.set_params(secp128r1)

            priv, pub = target.generate()

            with local(DefaultContext()) as ctx:
                mult.init(secp128r1, secp128r1.generator)
                expected = mult.multiply(priv).to_affine()

            assert secp128r1.curve.is_on_curve(pub)
            assert pub == expected
            trace_text = target.trace_file.read()
            print(trace_text)
            from_codegen = parse_trace(trace_text)
            # The first two recorded actions are the ones compared with the trace.
            from_sim = parse_ctx(ctx.actions[0]) + parse_ctx(ctx.actions[1])
            if from_codegen != from_sim:
                # Print a readable diff before failing; pytest shows the
                # captured output together with the assertion error.
                codegen_set = set(make_hashable(from_codegen))
                sim_set = set(make_hashable(from_sim))
                print(len(from_codegen), len(from_sim))
                print("In codegen but not in sim:")
                for entry in codegen_set - sim_set:
                    print(entry)
                print("In sim but not in codegen:")
                for entry in sim_set - codegen_set:
                    print(entry)
            assert from_codegen == from_sim

            target.quit()
        finally:
            # Always disconnect, so a failing iteration does not leave the
            # module-scoped target connected with its trace file open.
            target.disconnect()