1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
|
import subprocess
import json
from tempfile import NamedTemporaryFile
from typing import Generator, Any
from click.testing import CliRunner
import pytest
from importlib import resources
from os.path import join
from pyecsca.codegen.builder import build_impl
from pyecsca.ec.formula import FormulaAction, NegationFormula
from pyecsca.ec.model import CurveModel
from pyecsca.ec.coordinates import CoordinateModel
from pyecsca.sca.target.binary import BinaryTarget
from pyecsca.codegen.client import ImplTarget
from pyecsca.ec.context import DefaultContext, local, Node
class GDBTarget(ImplTarget, BinaryTarget):
    """Implementation target that runs the built binary under GDB.

    On :meth:`connect`, GDB is started in batch mode with the ``gdb_script.py``
    script from the ``test`` package. The script receives the path of a fresh
    temporary trace file via the ``TRACE_FILE`` environment variable, and the
    test reads the trace back through :attr:`trace_file`.
    """

    def __init__(self, model: CurveModel, coords: CoordinateModel, **kwargs):
        super().__init__(model, coords, **kwargs)
        # Temporary trace file of the current connection; None while disconnected.
        self.trace_file = None

    def connect(self):
        """Start GDB on ``self.binary`` with the tracing script attached."""
        import os

        # Do not leak the temp file of a previous, never-disconnected session.
        self._close_trace_file()
        self.trace_file = NamedTemporaryFile("r+")
        # Inherit the caller's environment. Passing only TRACE_FILE would strip
        # PATH (so "gdb" would be searched only in os.defpath) and every other
        # variable gdb and the traced binary may rely on.
        env = {**os.environ, "TRACE_FILE": self.trace_file.name}
        with resources.path("test", "gdb_script.py") as gdb_script:
            self.process = subprocess.Popen(
                [
                    "gdb",
                    "--q",
                    "--batch-silent",
                    "-x",
                    gdb_script,
                    "--args",
                    *self.binary,
                ],
                stdin=subprocess.PIPE,
                stdout=subprocess.PIPE,
                stderr=subprocess.DEVNULL,
                env=env,
                text=True,
                bufsize=1,
            )

    def disconnect(self):
        """Disconnect the underlying target and release the trace file."""
        super().disconnect()
        self._close_trace_file()

    def _close_trace_file(self):
        """Close the current trace file (NamedTemporaryFile deletes it on close)."""
        if self.trace_file is not None:
            self.trace_file.close()
            self.trace_file = None
@pytest.fixture(scope="module")
def target(simple_multiplier, secp128r1) -> Generator[GDBTarget, Any, None]:
    """Build a HOST implementation for the given multiplier and wrap it in a GDBTarget.

    ``simple_multiplier`` is unpacked as a ``(multiplier class, kwargs)`` pair.
    The implementation is built with ``DEBUG=1`` and debug info (``-g -O0``)
    inside a CliRunner isolated filesystem. The ``yield`` sits inside that
    ``with``, so the built binary stays available until the module's tests
    finish. The matching Python multiplier is attached as ``target.mult`` for
    the simulation side of the comparison.
    """
    mult_class, mult_kwargs = simple_multiplier
    mult_name = mult_class.__name__
    formulas = ["add-1998-cmo", "dbl-1998-cmo"]
    if NegationFormula in mult_class.requires:
        formulas.append("neg")
    mult_args = ",".join(f"{key}={value}" for key, value in mult_kwargs.items())
    runner = CliRunner()
    with runner.isolated_filesystem() as tmpdir:
        res = runner.invoke(
            build_impl,
            [
                "--platform",
                "HOST",
                "--ecdsa",
                "--ecdh",
                secp128r1.curve.model.shortname,
                secp128r1.curve.coordinate_model.name,
                *formulas,
                f"{mult_name}({mult_args})",
                ".",
            ],
            env={"DEBUG": "1", "CFLAGS": "-g -O0"},
        )
        # Show the build log when the build fails instead of a bare "1 != 0".
        assert res.exit_code == 0, res.output
        target = GDBTarget(
            secp128r1.curve.model,
            secp128r1.curve.coordinate_model,
            binary=join(tmpdir, "pyecsca-codegen-HOST.elf"),
        )
        formula_instances = [
            secp128r1.curve.coordinate_model.formulas[formula] for formula in formulas
        ]
        target.mult = mult_class(*formula_instances, **mult_kwargs)
        yield target
def parse_trace(captured: str):
    """Parse the GDB trace text into ``(function, args, rets)`` records.

    The trace is line-based:

    * A line without ``:`` names a function. A leading ``point_`` is dropped.
    * A ``name: <json>`` line is a value of the current function. Names
      containing ``out`` are outputs; everything else is an input.
    * Blank lines carry no information and are skipped.

    A ``set`` call discards the function it appears in. Sets that happen
    inside another formula (like add) are a sign of short-circuiting, which
    the Python simulation does not record.

    :param captured: Contents of the trace file.
    :return: List of ``(name, args, rets)`` tuples in trace order.
    """
    result = []
    current_function = None
    args = []
    rets = []
    for line in captured.split("\n"):
        name, sep, data = line.partition(":")
        if sep:
            value = json.loads(data)
            if "out" in name.strip():
                rets.append(value)
            else:
                args.append(value)
            continue
        func = line.strip()
        if not func:
            # Blank (or trailing) lines must not start a bogus "" record.
            continue
        if func.startswith("point_"):
            func = func[len("point_"):]
        if func == "set":
            # Drop the enclosing (short-circuited) formula; values that follow
            # are collected into lists that get discarded at the next function.
            current_function = None
            continue
        if current_function is not None:
            result.append((current_function, args, rets))
        current_function = func
        args = []
        rets = []
    # The final call has no following function line to flush it.
    if current_function is not None:
        result.append((current_function, args, rets))
    return result
def parse_ctx(scalarmult: Node):
    """Flatten the formula calls under a context node into ``(name, args, rets)`` records.

    Each child's action supplies the formula short name plus its input and
    output points. Every point is turned into a ``{coordinate: int}`` dict,
    so the records compare directly with the output of ``parse_trace``.
    """

    def as_int_dict(pt):
        return {coord: int(val) for coord, val in pt.coords.items()}

    records = []
    for child in scalarmult.children:
        act: FormulaAction = child.action
        records.append(
            (
                act.formula.shortname,
                [as_int_dict(pt) for pt in act.input_points],
                [as_int_dict(pt) for pt in act.output_points],
            )
        )
    return records
def make_hashable(trace):
    """Convert a parsed trace into a hashable tuple for set comparisons.

    Each point dict becomes a tuple of its items sorted by key. Two dicts
    that compare equal, whatever their insertion order, therefore map to the
    same tuple. This keeps set-based comparisons consistent with ``==`` on the
    parsed traces.

    :param trace: Iterable of ``(name, args, rets)``, where ``args``/``rets``
        are lists of dicts.
    :return: Tuple of ``(name, args_tuple, rets_tuple)`` entries.
    """

    def freeze(points):
        # Keys within one dict are unique, so sorting never compares values.
        return tuple(tuple(sorted(point.items())) for point in points)

    return tuple((name, freeze(args), freeze(rets)) for name, args, rets in trace)
def test_equivalence(target, secp128r1):
    """Check that the C implementation and the Python simulation make the same formula calls.

    For ten PRNG seeds, this test:

    * generates a key pair on the GDB-traced target;
    * recomputes the public key with the Python multiplier inside a recording
      context;
    * compares the key and the sequences of formula calls (name, input points,
      output points).

    The target is always quit and disconnected, even when an assertion fails,
    so no gdb process is left behind.
    """
    mult = target.mult
    for seed in range(10):
        target.connect()
        try:
            target.init_prng(seed.to_bytes(4, "big"))
            target.set_params(secp128r1)
            priv, pub = target.generate()
            with local(DefaultContext()) as ctx:
                mult.init(secp128r1, secp128r1.generator)
                expected = mult.multiply(priv).to_affine()
            assert secp128r1.curve.is_on_curve(pub)
            assert pub == expected
            trace = target.trace_file.read()
            print(trace)
            from_codegen = parse_trace(trace)
            # Formula calls recorded under the first two top-level context actions.
            from_sim = parse_ctx(ctx.actions[0]) + parse_ctx(ctx.actions[1])
            if from_codegen != from_sim:
                # Order-insensitive view of the mismatch to make failures readable.
                codegen_set = set(make_hashable(from_codegen))
                sim_set = set(make_hashable(from_sim))
                print(len(from_codegen), len(from_sim))
                print("In codegen but not in sim:")
                for entry in codegen_set - sim_set:
                    print(entry)
                print("In sim but not in codegen:")
                for entry in sim_set - codegen_set:
                    print(entry)
            assert from_codegen == from_sim
        finally:
            try:
                target.quit()
            finally:
                target.disconnect()
|