aboutsummaryrefslogtreecommitdiff
path: root/pyecsca/sca
diff options
context:
space:
mode:
authorJ08nY2025-08-21 16:48:52 +0200
committerJ08nY2025-08-21 16:48:52 +0200
commit8f95d0ff284cc48db26c1b916b548b6ad5967dfe (patch)
tree3576f0750d3ddef5fab5909507c3157658ee92f0 /pyecsca/sca
parentc36ee8a2c07ec58b556e505e4ecf61c871dc94c9 (diff)
downloadpyecsca-8f95d0ff284cc48db26c1b916b548b6ad5967dfe.tar.gz
pyecsca-8f95d0ff284cc48db26c1b916b548b6ad5967dfe.tar.zst
pyecsca-8f95d0ff284cc48db26c1b916b548b6ad5967dfe.zip
Diffstat (limited to 'pyecsca/sca')
-rw-r--r--pyecsca/sca/re/epa.py63
-rw-r--r--pyecsca/sca/re/rpa.py95
2 files changed, 109 insertions, 49 deletions
diff --git a/pyecsca/sca/re/epa.py b/pyecsca/sca/re/epa.py
index 5e2b78a..af07bef 100644
--- a/pyecsca/sca/re/epa.py
+++ b/pyecsca/sca/re/epa.py
@@ -12,10 +12,13 @@ from pyecsca.sca.re.rpa import MultipleContext
@public
def graph_to_check_inputs(
- ctx: MultipleContext,
+ precomp_ctx: MultipleContext,
+ full_ctx: MultipleContext,
out: Point,
check_condition: Union[Literal["all"], Literal["necessary"]],
precomp_to_affine: bool,
+ use_init: bool = True,
+ use_multiply: bool = True,
) -> dict[str, list[tuple[int, ...]]]:
"""
Compute the inputs for the checks based on the context and output point. This function traverses the graph of points
@@ -27,36 +30,65 @@ def graph_to_check_inputs(
:param out: The output point of the computation.
:param check_condition: Whether to check all points or only those necessary for the output point.
:param precomp_to_affine: Whether to include the precomputed points in the to-affine checks.
+ :param use_init: Whether to consider the point multiples that happen in scalarmult initialization.
+ :param use_multiply: Whether to consider the point multiples that happen in scalarmult multiply (after initialization).
:return: A dictionary mapping formula names to lists of tuples of input multiples.
.. note::
The scalar multiplier must not short-circuit.
"""
- affine_points = {out, *ctx.precomp.values()} if precomp_to_affine else {out}
- if check_condition == "all":
- points = set(ctx.points.keys())
- elif check_condition == "necessary":
- points = set(affine_points)
- queue = set(affine_points)
+ if not use_init and not use_multiply:
+ raise ValueError("At least one of use_init or use_multiply must be True.")
+
+ affine_points: set[Point] = set()
+ if use_init and use_multiply:
+ affine_points = (
+ {out, *precomp_ctx.precomp.values()} if precomp_to_affine else {out}
+ )
+ elif use_init:
+ affine_points = {*precomp_ctx.precomp.values()} if precomp_to_affine else set()
+ elif use_multiply:
+ affine_points = {out}
+
+ def _necessary(ctx, for_what):
+ res = {out}
+ queue = {*for_what}
while queue:
point = queue.pop()
for parent in ctx.parents[point]:
- points.add(parent)
+ res.add(parent)
queue.add(parent)
+ return res
+
+ points: set[Point] = set()
+ if check_condition == "all":
+ if use_init and use_multiply:
+ points = set(full_ctx.points.keys())
+ elif use_init:
+ points = set(precomp_ctx.points.keys())
+ elif use_multiply:
+ points = set(full_ctx.points.keys()) - set(precomp_ctx.points.keys())
+ elif check_condition == "necessary":
+ if use_init and use_multiply:
+ points = _necessary(full_ctx, affine_points)
+ elif use_init:
+ points = _necessary(full_ctx, affine_points) & set(precomp_ctx.points.keys())
+ elif use_multiply:
+ points = _necessary(full_ctx, affine_points) - set(precomp_ctx.points.keys())
else:
raise ValueError("check_condition must be 'all' or 'necessary'")
# Special case the "to affine" transform and checks
formula_checks: dict[str, list[tuple[int, ...]]] = {
- "affine": [(ctx.points[point],) for point in affine_points]
+ "affine": [(full_ctx.points[point],) for point in affine_points]
}
# This actually passes the multiple itself to the check, not the inputs(parents)
# Now handle the regular checks
for point in points:
- formula = ctx.formulas[point]
+ formula = full_ctx.formulas[point]
if not formula:
# Skip input point or infty point (they magically appear and do not have an origin formula)
continue
- inputs = tuple(map(lambda pt: ctx.points[pt], ctx.parents[point]))
+ inputs = tuple(map(lambda pt: full_ctx.points[pt], full_ctx.parents[point]))
check_list = formula_checks.setdefault(formula, [])
check_list.append(inputs)
return formula_checks
@@ -92,11 +124,14 @@ def evaluate_checks(
@public
def errors_out(
- ctx: MultipleContext,
+ precomp_ctx: MultipleContext,
+ full_ctx: MultipleContext,
out: Point,
check_funcs: dict[str, Callable],
check_condition: Union[Literal["all"], Literal["necessary"]],
precomp_to_affine: bool,
+ use_init: bool = True,
+ use_multiply: bool = True,
) -> bool:
"""
Check whether the computation errors out based on the provided context, output point, and check functions.
@@ -110,10 +145,12 @@ def errors_out(
of the base point and `q` is the base point order.
:param check_condition: Whether to check all points or only those necessary for the output point.
:param precomp_to_affine: Whether to include the precomputed points in the to-affine checks.
+ :param use_init: Whether to consider the point multiples that happen in scalarmult initialization.
+ :param use_multiply: Whether to consider the point multiples that happen in scalarmult multiply (after initialization).
:return: Whether any of the checks returned True -> whether the computation errors out.
.. note::
The scalar multiplier must not short-circuit.
"""
- formula_checks = graph_to_check_inputs(ctx, out, check_condition, precomp_to_affine)
+ formula_checks = graph_to_check_inputs(precomp_ctx, full_ctx, out, check_condition, precomp_to_affine, use_init, use_multiply)
return evaluate_checks(check_funcs, formula_checks)
diff --git a/pyecsca/sca/re/rpa.py b/pyecsca/sca/re/rpa.py
index a150b9c..3fd87ba 100644
--- a/pyecsca/sca/re/rpa.py
+++ b/pyecsca/sca/re/rpa.py
@@ -427,9 +427,7 @@ def multiple_graph(
params: DomainParameters,
mult_class: Type[ScalarMultiplier],
mult_factory: Callable,
- use_init: bool = True,
- use_multiply: bool = True,
-) -> Tuple[MultipleContext, Point]:
+) -> Tuple[MultipleContext, MultipleContext, Point]:
"""
Compute the multiples computed for a given scalar and multiplier (quickly).
@@ -437,29 +435,22 @@ def multiple_graph(
:param params: The domain parameters to use.
:param mult_class: The class of the scalar multiplier to use.
:param mult_factory: A callable that takes the formulas and instantiates the multiplier.
- :param use_init: Whether to consider the point multiples that happen in scalarmult initialization.
- :param use_multiply: Whether to consider the point multiples that happen in scalarmult multiply (after initialization).
:return: The context with the computed multiples and the resulting point.
"""
mult = cached_fake_mult(mult_class, mult_factory, params)
ctx = MultipleContext(keep_base=True)
- if use_init:
- with local(ctx, copy=False):
- mult.init(params, FakePoint(params.curve.coordinate_model))
- else:
+ with local(ctx, copy=False) as precomp_ctx:
mult.init(params, FakePoint(params.curve.coordinate_model))
- if use_multiply:
- with local(ctx, copy=False):
- out = mult.multiply(scalar)
- else:
+ with local(ctx, copy=True) as full_ctx:
out = mult.multiply(scalar)
- return ctx, out
+ return precomp_ctx, full_ctx, out
@public
def multiples_from_graph(
- ctx: MultipleContext,
+ precomp_ctx: MultipleContext,
+ full_ctx: MultipleContext,
out: Point,
kind: Union[
Literal["all"],
@@ -467,38 +458,72 @@ def multiples_from_graph(
Literal["necessary"],
Literal["precomp+necessary"],
] = "all",
+ use_init: bool = True,
+ use_multiply: bool = True,
):
"""
- :param ctx:
+ :param precomp_ctx:
+ :param full_ctx:
:param out:
- :param kind:
+ :param kind: The kind of multiples to return. Can be one of "all", "input", "necessary", or "precomp+necessary".
+ :param use_init: Whether to consider the point multiples that happen in scalarmult initialization.
+ :param use_multiply: Whether to consider the point multiples that happen in scalarmult multiply (after initialization).
:return: A set of multiples computed for the scalar.
"""
- if kind == "all":
- res = set(ctx.points.values())
- elif kind == "input":
+ if not use_init and not use_multiply:
+ return set()
+
+ def _all(ctx):
+ return set(ctx.points.values())
+
+ def _input(ctx):
res = set()
for point, multiple in ctx.points.items():
if point in ctx.parents:
for parent in ctx.parents[point]:
res.add(ctx.points[parent])
- elif kind == "necessary":
+ return res
+
+ def _necessary(ctx, for_what):
res = {ctx.points[out]}
- queue = {out}
+ queue = {*for_what}
while queue:
point = queue.pop()
for parent in ctx.parents[point]:
res.add(ctx.points[parent])
queue.add(parent)
+ return res
+
+ if kind == "all":
+ if use_init and use_multiply:
+ res = _all(full_ctx)
+ elif use_init:
+ res = _all(precomp_ctx)
+ elif use_multiply:
+ res = _all(full_ctx) - _all(precomp_ctx)
+ elif kind == "input":
+ if use_init and use_multiply:
+ res = _input(full_ctx)
+ elif use_init:
+ res = _input(precomp_ctx)
+ elif use_multiply:
+ res = _input(full_ctx) - _input(precomp_ctx)
+ elif kind == "necessary":
+ for_what = {out}
+ if use_init and use_multiply:
+ res = _necessary(full_ctx, for_what)
+ elif use_init:
+ res = _necessary(full_ctx, for_what) & _all(precomp_ctx)
+ elif use_multiply:
+ res = _necessary(full_ctx, for_what) - _all(precomp_ctx)
elif kind == "precomp+necessary":
- res = {ctx.points[out]}
- queue = {out, *ctx.precomp.values()}
- while queue:
- point = queue.pop()
- for parent in ctx.parents[point]:
- res.add(ctx.points[parent])
- queue.add(parent)
+ if use_init and use_multiply:
+ res = _necessary(full_ctx, {out, *precomp_ctx.precomp.values()})
+ elif use_init:
+ res = _necessary(precomp_ctx, {*precomp_ctx.precomp.values()})
+ elif use_multiply:
+ res = _necessary(full_ctx, {out}) - _all(precomp_ctx)
else:
raise ValueError(f"Invalid kind {kind}")
return res - {0}
@@ -533,13 +558,11 @@ def multiples_computed(
.. note::
The scalar multiplier must not short-circuit.
- If `kind` is not "all", `use_init` must be `True`.
"""
- if kind != "all" and not use_init:
- raise ValueError("Cannot use kind other than 'all' with use_init=False.")
+ precomp_ctx, full_ctx, out = multiple_graph(scalar, params, mult_class, mult_factory)
- ctx, out = multiple_graph(
- scalar, params, mult_class, mult_factory, use_init, use_multiply
+ return (
+ multiples_from_graph(precomp_ctx, full_ctx, out, kind, use_init, use_multiply)
+ if bool(precomp_ctx) and bool(full_ctx)
+ else set()
)
-
- return multiples_from_graph(ctx, out, kind) if ctx else set()