diff options
| author | J08nY | 2025-03-08 15:31:09 +0100 |
|---|---|---|
| committer | J08nY | 2025-03-08 15:31:09 +0100 |
| commit | 81f4d8ff8dab2a7884f0d74f23b3d78142eb176a (patch) | |
| tree | 9d8e36a12f0bb1068bbbbd7c749a7beea1b8fd5d | |
| parent | 5837599a7c80237776196e4c451810e121b74c54 (diff) | |
| download | pyecsca-81f4d8ff8dab2a7884f0d74f23b3d78142eb176a.tar.gz pyecsca-81f4d8ff8dab2a7884f0d74f23b3d78142eb176a.tar.zst pyecsca-81f4d8ff8dab2a7884f0d74f23b3d78142eb176a.zip | |
Improve context handling and make some stuff zero copy.
| -rw-r--r-- | pyecsca/ec/context.py | 16 | ||||
| -rw-r--r-- | pyecsca/sca/re/zvp.py | 23 | ||||
| -rw-r--r-- | test/ec/test_context.py | 32 |
3 files changed, 61 insertions, 10 deletions
diff --git a/pyecsca/ec/context.py b/pyecsca/ec/context.py index d64cbbc..674393b 100644 --- a/pyecsca/ec/context.py +++ b/pyecsca/ec/context.py @@ -308,9 +308,14 @@ current: Optional[Context] = None class _ContextManager: - def __init__(self, new_context): - # TODO: Is this deepcopy a good idea? - self.new_context = deepcopy(new_context) + def __init__(self, new_context: Optional[Context] = None, copy: bool = True): + if copy: + if new_context is not None: + self.new_context = deepcopy(new_context) + else: + self.new_context = deepcopy(current) + else: + self.new_context = new_context if new_context is not None else current def __enter__(self) -> Optional[Context]: global current # This is OK, skipcq: PYL-W0603 @@ -324,7 +329,7 @@ class _ContextManager: @public -def local(ctx: Optional[Context] = None) -> ContextManager: +def local(ctx: Optional[Context] = None, copy: bool = True) -> ContextManager: """ Use a local context. @@ -337,6 +342,7 @@ def local(ctx: Optional[Context] = None) -> ContextManager: True :param ctx: If none, current context is copied. + :param copy: Whether to copy the context. :return: A context manager. """ - return _ContextManager(ctx) + return _ContextManager(ctx, copy) diff --git a/pyecsca/sca/re/zvp.py b/pyecsca/sca/re/zvp.py index 298158a..74c8902 100644 --- a/pyecsca/sca/re/zvp.py +++ b/pyecsca/sca/re/zvp.py @@ -590,6 +590,8 @@ def addition_chain( params: DomainParameters, mult_class: Type[ScalarMultiplier], mult_factory, + use_init: bool = False, + use_multiply: bool = True ) -> List[Tuple[str, Tuple[int, ...]]]: """ Compute the addition chain for a given scalar and multiplier. @@ -598,6 +600,8 @@ def addition_chain( :param params: The domain parameters to use. :param mult_class: The class of the scalar multiplier to use. :param mult_factory: A callable that takes the formulas and instantiates the multiplier. + :param use_init: Whether to consider the point multiples that happen in scalarmult initialization. + :param use_multiply: Whether to consider the point multiples that happen in scalarmult multiply (after initialization). :return: A list of tuples, where the first element is the formula shortname (e.g. "add") and the second is a tuple of the dlog relationships to the input of the input points to the formula. """ @@ -618,17 +622,26 @@ def addition_chain( for subclass in formula.__subclasses__(): if issubclass(subclass, FakeFormula): formulas.append(subclass(params.curve.coordinate_model)) + + ctx = MultipleContext() mult = mult_factory(*formulas) - mult.init(params, FakePoint(params.curve.coordinate_model)) + if use_init: + with local(ctx, copy=False): + mult.init(params, FakePoint(params.curve.coordinate_model)) + else: + mult.init(params, FakePoint(params.curve.coordinate_model)) - with local(MultipleContext()) as mctx: + if use_multiply: + with local(ctx, copy=False): + mult.multiply(scalar) + else: mult.multiply(scalar) chain = [] - for point, parents in mctx.parents.items(): + for point, parents in ctx.parents.items(): if not parents: continue - formula_type = mctx.formulas[point] - ks = tuple(mctx.points[parent] for parent in parents) + formula_type = ctx.formulas[point] + ks = tuple(ctx.points[parent] for parent in parents) chain.append((formula_type, ks)) return chain diff --git a/test/ec/test_context.py b/test/ec/test_context.py index 9ad8b52..f76995b 100644 --- a/test/ec/test_context.py +++ b/test/ec/test_context.py @@ -89,6 +89,38 @@ def test_default_no_enter(): with local(DefaultContext()) as default, pytest.raises(ValueError): default.exit_action(RandomModAction(7)) +def test_multiple_enter(mult): + default = DefaultContext() + with local(default) as ctx1: + mult.multiply(59) + + with local(default) as ctx2: + mult.multiply(135) + + assert len(default.actions) == 0 + assert len(ctx1.actions) == len(ctx2.actions) + +def test_multiple_enter_chained(mult): + default = DefaultContext() + with local(default) as ctx1: + mult.multiply(59) + + with local(ctx1) as ctx2: + mult.multiply(135) + + assert len(default.actions) == 0 + assert 2 * len(ctx1.actions) == len(ctx2.actions) + +def test_multiple_enter_no_copy(mult): + default = DefaultContext() + with local(default, copy=False) as ctx1: + mult.multiply(59) + + with local(default, copy=False) as ctx2: + mult.multiply(135) + + assert len(default.actions) == len(ctx1.actions) + assert len(ctx1.actions) == len(ctx2.actions) def test_path(mult, secp128r1): with local(PathContext([0, 1])) as ctx: |
