diff options
Diffstat (limited to 'pyecsca/sca/re/rpa.py')
| -rw-r--r-- | pyecsca/sca/re/rpa.py | 95 |
1 files changed, 59 insertions, 36 deletions
diff --git a/pyecsca/sca/re/rpa.py b/pyecsca/sca/re/rpa.py index a150b9c..3fd87ba 100644 --- a/pyecsca/sca/re/rpa.py +++ b/pyecsca/sca/re/rpa.py @@ -427,9 +427,7 @@ def multiple_graph( params: DomainParameters, mult_class: Type[ScalarMultiplier], mult_factory: Callable, - use_init: bool = True, - use_multiply: bool = True, -) -> Tuple[MultipleContext, Point]: +) -> Tuple[MultipleContext, MultipleContext, Point]: """ Compute the multiples computed for a given scalar and multiplier (quickly). @@ -437,29 +435,22 @@ def multiple_graph( :param params: The domain parameters to use. :param mult_class: The class of the scalar multiplier to use. :param mult_factory: A callable that takes the formulas and instantiates the multiplier. - :param use_init: Whether to consider the point multiples that happen in scalarmult initialization. - :param use_multiply: Whether to consider the point multiples that happen in scalarmult multiply (after initialization). :return: The context with the computed multiples and the resulting point. """ mult = cached_fake_mult(mult_class, mult_factory, params) ctx = MultipleContext(keep_base=True) - if use_init: - with local(ctx, copy=False): - mult.init(params, FakePoint(params.curve.coordinate_model)) - else: + with local(ctx, copy=False) as precomp_ctx: mult.init(params, FakePoint(params.curve.coordinate_model)) - if use_multiply: - with local(ctx, copy=False): - out = mult.multiply(scalar) - else: + with local(ctx, copy=True) as full_ctx: out = mult.multiply(scalar) - return ctx, out + return precomp_ctx, full_ctx, out @public def multiples_from_graph( - ctx: MultipleContext, + precomp_ctx: MultipleContext, + full_ctx: MultipleContext, out: Point, kind: Union[ Literal["all"], @@ -467,38 +458,72 @@ def multiples_from_graph( Literal["necessary"], Literal["precomp+necessary"], ] = "all", + use_init: bool = True, + use_multiply: bool = True, ): """ - :param ctx: + :param precomp_ctx: + :param full_ctx: :param out: - :param kind: + :param kind: The kind of multiples to return. Can be one of "all", "input", "necessary", or "precomp+necessary". + :param use_init: Whether to consider the point multiples that happen in scalarmult initialization. + :param use_multiply: Whether to consider the point multiples that happen in scalarmult multiply (after initialization). :return: A set of multiples computed for the scalar. """ - if kind == "all": - res = set(ctx.points.values()) - elif kind == "input": + if not use_init and not use_multiply: + return set() + + def _all(ctx): + return set(ctx.points.values()) + + def _input(ctx): res = set() for point, multiple in ctx.points.items(): if point in ctx.parents: for parent in ctx.parents[point]: res.add(ctx.points[parent]) - elif kind == "necessary": + return res + + def _necessary(ctx, for_what): res = {ctx.points[out]} - queue = {out} + queue = {*for_what} while queue: point = queue.pop() for parent in ctx.parents[point]: res.add(ctx.points[parent]) queue.add(parent) + return res + + if kind == "all": + if use_init and use_multiply: + res = _all(full_ctx) + elif use_init: + res = _all(precomp_ctx) + elif use_multiply: + res = _all(full_ctx) - _all(precomp_ctx) + elif kind == "input": + if use_init and use_multiply: + res = _input(full_ctx) + elif use_init: + res = _input(precomp_ctx) + elif use_multiply: + res = _input(full_ctx) - _input(precomp_ctx) + elif kind == "necessary": + for_what = {out} + if use_init and use_multiply: + res = _necessary(full_ctx, for_what) + elif use_init: + res = _necessary(full_ctx, for_what) & _all(precomp_ctx) + elif use_multiply: + res = _necessary(full_ctx, for_what) - _all(precomp_ctx) elif kind == "precomp+necessary": - res = {ctx.points[out]} - queue = {out, *ctx.precomp.values()} - while queue: - point = queue.pop() - for parent in ctx.parents[point]: - res.add(ctx.points[parent]) - queue.add(parent) + if use_init and use_multiply: + res = _necessary(full_ctx, {out, *precomp_ctx.precomp.values()}) + elif use_init: + res = _necessary(precomp_ctx, {*precomp_ctx.precomp.values()}) + elif use_multiply: + res = _necessary(full_ctx, {out}) - _all(precomp_ctx) else: raise ValueError(f"Invalid kind {kind}") return res - {0} @@ -533,13 +558,11 @@ def multiples_computed( .. note:: The scalar multiplier must not short-circuit. - If `kind` is not "all", `use_init` must be `True`. """ - if kind != "all" and not use_init: - raise ValueError("Cannot use kind other than 'all' with use_init=False.") + precomp_ctx, full_ctx, out = multiple_graph(scalar, params, mult_class, mult_factory) - ctx, out = multiple_graph( - scalar, params, mult_class, mult_factory, use_init, use_multiply + return ( + multiples_from_graph(precomp_ctx, full_ctx, out, kind, use_init, use_multiply) + if bool(precomp_ctx) and bool(full_ctx) + else set() ) - - return multiples_from_graph(ctx, out, kind) if ctx else set() |
