1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
|
"""
Provides functionality inspired by the Refined-Power Analysis attack by Goubin [RPA]_.
"""
from copy import copy, deepcopy
from anytree import RenderTree
from public import public
from typing import MutableMapping, Optional, Callable, List, Set
from sympy import FF, sympify, Poly, symbols
from .tree import Tree, Map
from ...ec.coordinates import AffineCoordinateModel
from ...ec.formula import (
FormulaAction,
DoublingFormula,
AdditionFormula,
TriplingFormula,
NegationFormula,
DifferentialAdditionFormula,
LadderFormula,
)
from ...ec.mod import Mod
from ...ec.mult import (
ScalarMultiplicationAction,
PrecomputationAction,
ScalarMultiplier,
)
from ...ec.params import DomainParameters
from ...ec.model import ShortWeierstrassModel, MontgomeryModel
from ...ec.point import Point
from ...ec.context import Context, Action, local
from ...misc.utils import log, warn
@public
class MultipleContext(Context):
    """
    Context that traces the multiples of points computed.

    While a scalar multiplication or precomputation action is active, each recognized
    formula application is recorded: its output point is mapped to the multiple of the
    base point it represents, to its input (parent) points and to the short name of the
    formula that produced it.
    """

    base: Optional[Point]
    """The base point that all the multiples are counted from."""
    points: MutableMapping[Point, int]
    """The mapping of points to the multiples they represent (e.g., base -> 1)."""
    parents: MutableMapping[Point, List[Point]]
    """The mapping of points to the parent points they were computed from."""
    formulas: MutableMapping[Point, str]
    """The mapping of points to the short name of the formula they are a result of."""
    inside: bool
    """Whether a scalar multiplication or precomputation action is currently active."""

    def __init__(self):
        self.base = None
        self.points = {}
        self.parents = {}
        self.formulas = {}
        self.inside = False

    def _reset(self, base: Point) -> None:
        """Forget everything traced so far and start counting multiples from `base`."""
        self.base = base
        self.points = {base: 1}
        self.parents = {base: []}
        self.formulas = {base: ""}

    def _record(
        self, out: Point, multiple: int, parents: List[Point], formula: str
    ) -> None:
        """Store the multiple, parent points and formula short name of a computed point."""
        self.points[out] = multiple
        self.parents[out] = parents
        self.formulas[out] = formula

    def enter_action(self, action: Action) -> None:
        """
        Start tracing when a scalar multiplication or precomputation action begins.

        The traced state is kept only if the action works on the same base point as the
        previous one (i.e. it builds on top of it); otherwise it is reset to the new base.
        """
        if isinstance(action, (ScalarMultiplicationAction, PrecomputationAction)):
            if not self.base or self.base != action.point:
                # Either nothing was traced yet, or the new action uses a different base:
                # forget stuff and set a new base and mapping.
                self._reset(action.point)
            self.inside = True

    def exit_action(self, action: Action) -> None:
        """
        Record the result of a formula application, or stop tracing.

        Formula actions are only recorded while inside a scalar multiplication or
        precomputation action; formula types not handled below are ignored.
        """
        if isinstance(action, (ScalarMultiplicationAction, PrecomputationAction)):
            self.inside = False
        if not (isinstance(action, FormulaAction) and self.inside):
            return
        formula = action.formula
        if isinstance(formula, DoublingFormula):
            inp = action.input_points[0]
            out = action.output_points[0]
            self._record(out, 2 * self.points[inp], [inp], formula.shortname)
        elif isinstance(formula, TriplingFormula):
            inp = action.input_points[0]
            out = action.output_points[0]
            self._record(out, 3 * self.points[inp], [inp], formula.shortname)
        elif isinstance(formula, AdditionFormula):
            one, other = action.input_points
            out = action.output_points[0]
            self._record(
                out,
                self.points[one] + self.points[other],
                [one, other],
                formula.shortname,
            )
        elif isinstance(formula, NegationFormula):
            inp = action.input_points[0]
            out = action.output_points[0]
            self._record(out, -self.points[inp], [inp], formula.shortname)
        elif isinstance(formula, DifferentialAdditionFormula):
            # The first input is not tracked (presumably the difference of the two points).
            _, one, other = action.input_points
            out = action.output_points[0]
            self._record(
                out,
                self.points[one] + self.points[other],
                [one, other],
                formula.shortname,
            )
        elif isinstance(formula, LadderFormula):
            # A ladder step yields both the doubling of `one` and the sum `one + other`;
            # the first input is not tracked here either.
            _, one, other = action.input_points
            dbl, add = action.output_points
            # The sum is recorded first, so the doubling wins if both outputs coincide.
            self._record(
                add,
                self.points[one] + self.points[other],
                [one, other],
                formula.shortname,
            )
            self._record(dbl, 2 * self.points[one], [one], formula.shortname)

    def __repr__(self):
        return f"{self.__class__.__name__}({self.base!r}, multiples={self.points.values()!r})"
@public
def rpa_point_0y(params: DomainParameters) -> Optional[Point]:
    """
    Construct an (affine) [RPA]_ point (0, y) for given domain parameters.

    :param params: The domain parameters whose curve the point should lie on.
    :return: The point (0, sqrt(b)) on a short Weierstrass curve, or `None` if `b` is not
        a quadratic residue; the point (0, 0) on a Montgomery curve.
    :raises NotImplementedError: For any other curve model.
    """
    curve = params.curve
    model = curve.model
    if isinstance(model, ShortWeierstrassModel):
        b = curve.parameters["b"]
        if not b.is_residue():
            # With x = 0 the curve equation leaves y^2 = b, so b must have a square root.
            return None
        root = b.sqrt()
        # TODO: We can take the negative as well.
        return Point(AffineCoordinateModel(model), x=Mod(0, curve.prime), y=root)
    if isinstance(model, MontgomeryModel):
        return Point(
            AffineCoordinateModel(model),
            x=Mod(0, curve.prime),
            y=Mod(0, curve.prime),
        )
    raise NotImplementedError
@public
def rpa_point_x0(params: DomainParameters) -> Optional[Point]:
    """
    Construct an (affine) [RPA]_ point (x, 0) for given domain parameters.

    :param params: The domain parameters whose curve the point should lie on.
    :return: A point (x, 0) on a short Weierstrass curve, where x is a root of
        x^3 + a*x + b over the base field, or `None` if `order * cofactor` is odd or no
        such root exists; the point (0, 0) on a Montgomery curve.
    :raises NotImplementedError: For any other curve model.
    """
    curve = params.curve
    model = curve.model
    if isinstance(model, ShortWeierstrassModel):
        if (params.order * params.cofactor) % 2 != 0:
            # No (x, 0) point is looked for when order * cofactor is odd.
            return None
        field = FF(curve.prime)
        # Right-hand side of the curve equation with the curve coefficients plugged in.
        rhs = (
            sympify("x^3 + a * x + b", evaluate=False)
            .subs("a", field(int(curve.parameters["a"])))
            .subs("b", field(int(curve.parameters["b"])))
        )
        found = Poly(rhs, symbols("x"), domain=field).ground_roots()
        if not found:
            return None
        # Any root will do; take the first one reported.
        first_root = next(iter(found))
        return Point(
            AffineCoordinateModel(model),
            x=Mod(int(first_root), curve.prime),
            y=Mod(0, curve.prime),
        )
    if isinstance(model, MontgomeryModel):
        return Point(
            AffineCoordinateModel(model),
            x=Mod(0, curve.prime),
            y=Mod(0, curve.prime),
        )
    raise NotImplementedError
@public
def rpa_input_point(k: Mod, rpa_point: Point, params: DomainParameters) -> Point:
    """
    Construct an (affine) input point P that will lead to an RPA point [k]P.

    :param k: The multiple at which the RPA point should appear.
    :param rpa_point: The RPA point to hit.
    :param params: The domain parameters providing the curve.
    :return: The RPA point multiplied (in affine form) by the inverse of `k`.
    """
    inverse_scalar = int(k.inverse())
    return params.curve.affine_multiply(rpa_point, inverse_scalar)
@public
def rpa_distinguish(
    params: DomainParameters,
    multipliers: List[ScalarMultiplier],
    oracle: Callable[[int, Point], bool],
    bound: Optional[int] = None,
    majority: int = 1,
    use_init: bool = True,
    use_multiply: bool = True,
) -> Set[ScalarMultiplier]:
    """
    Distinguish the scalar multiplier used (from the possible :paramref:`~.rpa_distinguish.multipliers`) using
    an [RPA]_ :paramref:`~.rpa_distinguish.oracle`.

    :param params: The domain parameters to use.
    :param multipliers: The list of possible multipliers.
    :param oracle: An oracle that returns `True` when an RPA point is encountered during scalar multiplication of the input by the scalar.
    :param bound: A bound on the size of the scalar to consider. Defaults to `params.order`.
    :param majority: Query the oracle up to `majority` times and take the majority vote of the results. Has to be a positive odd number.
    :param use_init: Whether to consider the point multiples that happen in scalarmult initialization.
    :param use_multiply: Whether to consider the point multiples that happen in scalarmult multiply (after initialization).
    :return: The set of possible multipliers after distinguishing (ideally just one).
    :raises ValueError: If `majority` is even or not positive, if neither `use_init` nor
        `use_multiply` is set, or if no RPA point (0, y) exists for the curve.
    """
    if (majority % 2) == 0:
        raise ValueError("Cannot use even majority.")
    if majority < 1:
        # A negative odd value would skip the voting loop below and leave the vote undefined.
        raise ValueError("Majority has to be positive.")
    if not (use_init or use_multiply):
        raise ValueError("Has to use either init or multiply or both.")
    P0 = rpa_point_0y(params)
    if not P0:
        raise ValueError("There are no RPA-points on the provided curve.")
    log(f"Got RPA point {P0}.")
    if not bound:
        bound = params.order

    # Work on copies so that the caller's multipliers are not initialized as a side effect.
    mults = {copy(mult) for mult in multipliers}
    # Trace the initialization of every multiplier once; it does not depend on the scalar.
    init_contexts = {}
    for mult in mults:
        with local(MultipleContext()) as ctx:
            mult.init(params, params.generator)
        init_contexts[mult] = ctx

    tries = 0
    while True:
        if tries > 10:
            warn("Tried more than 10 times. Aborting.")
            return mults
        scalar = int(Mod.random(bound))
        log(f"Got scalar {scalar}")
        log([mult.__class__.__name__ for mult in mults])
        mults_to_multiples = {}
        for mult in mults:
            # Copy the context after init to not accumulate multiples by accident here.
            init_context = deepcopy(init_contexts[mult])
            # Take the points computed during init and the multiples of their parents.
            # This has to happen before the multiply below extends the same context.
            init_points = set(init_context.parents)
            init_multiples = _parent_multiples(init_context, init_points, params.order)
            # Now do the multiply and repeat the above, but only consider new computed points
            with local(init_context) as ctx:
                mult.multiply(scalar)
            new_points = set(ctx.parents) - init_points
            multiply_multiples = _parent_multiples(ctx, new_points, params.order)
            used = set()
            if use_init:
                used |= init_multiples
            if use_multiply:
                used |= multiply_multiples
            mults_to_multiples[mult] = used

        dmap = Map.from_sets(set(mults), mults_to_multiples)
        tree = Tree.build(set(mults), dmap)
        log("Built distinguishing tree.")
        log(tree.render())
        if tree.size == 1:
            # This scalar does not separate the remaining multipliers, try another one.
            tries += 1
            continue

        # Walk the tree from the root, asking the oracle at every inner node.
        current_node = tree.root
        while current_node.children:
            best_distinguishing_multiple: Mod = current_node.dmap_input  # type: ignore
            P0_inverse = rpa_input_point(best_distinguishing_multiple, P0, params)
            responses = []
            for _ in range(majority):
                responses.append(oracle(scalar, P0_inverse))
                # Stop querying as soon as one answer holds a strict majority.
                if responses.count(True) > (majority // 2):
                    response = True
                    break
                if responses.count(False) > (majority // 2):
                    response = False
                    break
            log(f"Oracle response -> {response}")
            for mult in mults:
                log(
                    mult.__class__.__name__,
                    best_distinguishing_multiple in mults_to_multiples[mult],
                )
            response_map = {child.response: child for child in current_node.children}
            current_node = response_map[response]
        mults = current_node.cfgs
        log([mult.__class__.__name__ for mult in mults])
        log()

        if len(mults) == 1:
            return mults


def _parent_multiples(context: MultipleContext, points: Set[Point], order: int) -> Set[Mod]:
    """Return the multiples (mod `order`, with both signs) of the parents of `points` in `context`."""
    parents = {parent for point in points for parent in context.parents[point]}
    multiples = {Mod(context.points[parent], order) for parent in parents}
    return multiples | {-multiple for multiple in multiples}
|