diff options
Diffstat (limited to 'epare/common.py')
| -rw-r--r-- | epare/common.py | 132 |
1 files changed, 127 insertions, 5 deletions
diff --git a/epare/common.py b/epare/common.py index 3b10007..4689ce5 100644 --- a/epare/common.py +++ b/epare/common.py @@ -10,10 +10,12 @@ from dataclasses import dataclass from functools import partial, cached_property from importlib import import_module, invalidate_caches from pathlib import Path -from typing import Type, Any +from typing import Type, Any, Optional +from enum import Enum from pyecsca.ec.params import DomainParameters, get_params from pyecsca.ec.mult import * +from pyecsca.ec.countermeasures import GroupScalarRandomization, AdditiveSplitting, MultiplicativeSplitting, EuclideanSplitting spawn_context = multiprocessing.get_context("spawn") @@ -42,24 +44,46 @@ class MultIdent: klass: Type[ScalarMultiplier] args: list[Any] kwargs: dict[str, Any] + countermeasure: Optional[str] = None def __init__(self, klass: Type[ScalarMultiplier], *args, **kwargs): object.__setattr__(self, "klass", klass) object.__setattr__(self, "args", args if args is not None else []) + if kwargs is not None and "countermeasure" in kwargs: + object.__setattr__(self, "countermeasure", kwargs["countermeasure"]) + del kwargs["countermeasure"] object.__setattr__(self, "kwargs", kwargs if kwargs is not None else {}) - + @cached_property def partial(self): - return partial(self.klass, *self.args, **self.kwargs) + func = partial(self.klass, *self.args, **self.kwargs) + if self.countermeasure is None: + return func + if self.countermeasure == "gsr": + return lambda *args, **kwargs: GroupScalarRandomization(func(*args, **kwargs)) + elif self.countermeasure == "additive": + return lambda *args, **kwargs: AdditiveSplitting(func(*args, **kwargs)) + elif self.countermeasure == "multiplicative": + return lambda *args, **kwargs: MultiplicativeSplitting(func(*args, **kwargs)) + elif self.countermeasure == "euclidean": + return lambda *args, **kwargs: EuclideanSplitting(func(*args, **kwargs)) + + def with_countermeasure(self, countermeasure: str): + if countermeasure not in (None, "gsr", "additive", "multiplicative", "euclidean"): + raise ValueError(f"Unknown countermeasure: {countermeasure}") + return MultIdent(self.klass, *self.args, **self.kwargs, countermeasure=countermeasure) def __str__(self): - return f"{self.klass.__name__}_{self.args}_{self.kwargs}" + args = ("_" + ",".join(list(map(str, self.args)))) if self.args else "" + kwargs = ("_" + ",".join(f"{str(k)}:{v.name if isinstance(v, Enum) else str(v)}" for k,v in self.kwargs.items())) if self.kwargs else "" + countermeasure = f"+{self.countermeasure}" if self.countermeasure is not None else "" + return f"{self.klass.__name__}{args}{kwargs}{countermeasure}" def __repr__(self): return str(self) def __hash__(self): - return hash((self.klass, tuple(self.args), tuple(self.kwargs.keys()), tuple(self.kwargs.values()))) + return hash((self.klass, self.countermeasure, tuple(self.args), tuple(self.kwargs.keys()), tuple(self.kwargs.values()))) @dataclass @@ -85,3 +109,101 @@ class MultResults: def __repr__(self): return str(self) + + +@dataclass +class ProbMap: + probs: dict[int, float] + samples: int + + def __len__(self): + return len(self.probs) + + def __iter__(self): + yield from self.probs + + def __getitem__(self, i): + return self.probs[i] + + def keys(self): + return self.probs.keys() + + def values(self): + return self.probs.values() + + def items(self): + return self.probs.items() + + def merge(self, other: "ProbMap"): + new_keys = set(self.keys()).union(other.keys()) + result = {} + for key in new_keys: + if key in self and key in other: + result[key] = (self[key] * self.samples + other[key] * other.samples) / (self.samples + other.samples) + elif key in self: + result[key] = self[key] + elif key in other: + result[key] = other[key] + self.probs = result + self.samples += other.samples + + def enrich(self, other: "ProbMap"): + if self.samples != other.samples: + raise ValueError("Enriching can only work on equal amount of samples (same run, different divisors)") + self.probs.update(other.probs) + +# All dbl-and-add multipliers from https://github.com/J08nY/pyecsca/blob/master/pyecsca/ec/mult + +window_mults = [ + MultIdent(SlidingWindowMultiplier, width=3), + MultIdent(SlidingWindowMultiplier, width=4), + MultIdent(SlidingWindowMultiplier, width=5), + MultIdent(SlidingWindowMultiplier, width=6), + MultIdent(FixedWindowLTRMultiplier, m=2**4), + MultIdent(FixedWindowLTRMultiplier, m=2**5), + MultIdent(FixedWindowLTRMultiplier, m=2**6), + MultIdent(WindowBoothMultiplier, width=3), + MultIdent(WindowBoothMultiplier, width=4), + MultIdent(WindowBoothMultiplier, width=5), + MultIdent(WindowBoothMultiplier, width=6) +] +naf_mults = [ + MultIdent(WindowNAFMultiplier, width=3), + MultIdent(WindowNAFMultiplier, width=4), + MultIdent(WindowNAFMultiplier, width=5), + MultIdent(WindowNAFMultiplier, width=6), + MultIdent(BinaryNAFMultiplier, direction=ProcessingDirection.LTR), + MultIdent(BinaryNAFMultiplier, direction=ProcessingDirection.RTL) +] +comb_mults = [ + MultIdent(CombMultiplier, width=2), + MultIdent(CombMultiplier, width=3), + MultIdent(CombMultiplier, width=4), + MultIdent(CombMultiplier, width=5), + MultIdent(CombMultiplier, width=6), + MultIdent(BGMWMultiplier, width=2, direction=ProcessingDirection.LTR), + MultIdent(BGMWMultiplier, width=3, direction=ProcessingDirection.LTR), + MultIdent(BGMWMultiplier, width=4, direction=ProcessingDirection.LTR), + MultIdent(BGMWMultiplier, width=5, direction=ProcessingDirection.LTR), + MultIdent(BGMWMultiplier, width=6, direction=ProcessingDirection.LTR), + MultIdent(BGMWMultiplier, width=2, direction=ProcessingDirection.RTL), + MultIdent(BGMWMultiplier, width=3, direction=ProcessingDirection.RTL), + MultIdent(BGMWMultiplier, width=4, direction=ProcessingDirection.RTL), + MultIdent(BGMWMultiplier, width=5, direction=ProcessingDirection.RTL), + MultIdent(BGMWMultiplier, width=6, direction=ProcessingDirection.RTL) +] +binary_mults = [ + MultIdent(LTRMultiplier, always=False), + MultIdent(LTRMultiplier, always=True), + MultIdent(RTLMultiplier, always=False), + MultIdent(RTLMultiplier, always=True), + MultIdent(CoronMultiplier) +] +other_mults = [ + MultIdent(FullPrecompMultiplier, always=False), + MultIdent(FullPrecompMultiplier, always=True), + MultIdent(SimpleLadderMultiplier, complete=True), + MultIdent(SimpleLadderMultiplier, complete=False) +] + +all_mults = window_mults + naf_mults + binary_mults + other_mults + comb_mults |
