diff options
| author | J08nY | 2025-08-05 21:13:56 +0200 |
|---|---|---|
| committer | J08nY | 2025-08-05 21:13:56 +0200 |
| commit | bb50056bf5c2834e7814681e9812d32ffd39a030 (patch) | |
| tree | 554085ccc655164f4707891e89e1e908538a9cf7 | |
| parent | e6e2cf8b31996c8dc42eae20c5afad40ba382c38 (diff) | |
| download | ECTester-bb50056bf5c2834e7814681e9812d32ffd39a030.tar.gz ECTester-bb50056bf5c2834e7814681e9812d32ffd39a030.tar.zst ECTester-bb50056bf5c2834e7814681e9812d32ffd39a030.zip | |
Add docs for common functions.
| -rw-r--r-- | analysis/scalarmults/common.py | 82 |
1 files changed, 68 insertions, 14 deletions
diff --git a/analysis/scalarmults/common.py b/analysis/scalarmults/common.py index d3a784c..957afe9 100644 --- a/analysis/scalarmults/common.py +++ b/analysis/scalarmults/common.py @@ -17,10 +17,12 @@ from pyecsca.ec.countermeasures import GroupScalarRandomization, AdditiveSplitti def check_equal_multiples(k, l, q): + """Checks whether the two multiples input into the formula are equal modulo q (the order of the base).""" return (k % q) == (l % q) def check_divides(k, l, q): + """Checks whether q (the order of the base) divides any of the multiples input into the formula.""" return (k % q == 0) or (l % q == 0) @@ -29,10 +31,12 @@ def check_half_add(k, l, q): def check_affine(k, q): + """Checks whether q (the order of the base) divides the multiple that is to be converted to affine.""" return k % q == 0 def check_any(*checks, q=None): + """Merge multiple checks together. The returned check function no longer takes the `q` parameter.""" def check_func(k, l): for check in checks: if check(k, l, q): @@ -41,24 +45,20 @@ def check_any(*checks, q=None): return check_func +# These checks can be applied to add formulas. See the formulas notebook for background on them. checks_add = { "equal_multiples": check_equal_multiples, "divides": check_divides, "half_add": check_half_add } +# This check can be applied to conversion to affine. checks_affine = { "affine": check_affine } - -def powers_of(k, max_power=20): - return [k**i for i in range(1, max_power)] - -def prod_combine(one, other): - return [a * b for a, b in itertools.product(one, other)] - def powerset(iterable): + """Take an iterable and create a powerset of its elements.""" s = list(iterable) return map(set, itertools.chain.from_iterable(itertools.combinations(s, r) for r in range(len(s)+1))) @@ -66,6 +66,20 @@ def powerset(iterable): @dataclass(frozen=True) @total_ordering class ErrorModel: + """ + An ErrorModel describes the behavior of an implementation with regards to errors on exceptional + inputs to its addition formulas, to-affine conversion or general scalar multiplication. + + :param checks: A set of names of checks (from checks_add and checks_affine) that the implementation performs. + Note that these may not be checks that the implementation explicitly performs, only that it behaves w.r.t. + errors as if it were doing these checks, due to the formulas it chose and any actual checks it has. + :param check_condition: Either "all" or "necessary". Specifies whether the checks are applied to all points + that the implementation computes during a scalar multiplication or only those that end up being used -- thus + affect -- the final result. If an implementation does not perform any dummy operations, these two are the same. + :param precomp_to_affine: Specifies whether the implementation converts all results of the precomputation step + to affine form. If it does, it means that additional checks on all outputs of the precomputation are done as + they have to be "convertible" to affine form. + """ checks: set[str] check_condition: Union[Literal["all"], Literal["necessary"]] precomp_to_affine: bool @@ -83,11 +97,13 @@ class ErrorModel: object.__setattr__(self, "precomp_to_affine", precomp_to_affine) def check_add(self, q): + """Get the add formula check function for the given q.""" if self.checks == {"affine"}: return lambda k, l: False return check_any(*map(lambda name: checks_add[name], filter(lambda check: check in checks_add, self.checks)), q=q) def check_affine(self, q): + """Get the to-affine check function for the given q.""" return partial(check_affine, q=q) def __lt__(self, other): @@ -112,6 +128,7 @@ class ErrorModel: return hash((tuple(sorted(self.checks)), self.check_condition, self.precomp_to_affine)) +# All error models are a simple cartesian product of the individual options. all_error_models = [] for checks in powerset(checks_add): for precomp_to_affine in (True, False): @@ -123,6 +140,22 @@ for checks in powerset(checks_add): @dataclass(frozen=True) @total_ordering class MultIdent: + """ + A MultIdent is a description of a scalar multiplication implementation, consisting of a scalar multiplier, + (optionally) a countermeasure, and (optionally) an error model. + + The scalar multiplier is defined by the `klass` attribute, along with the `args` and `kwargs` attributes. + One can reconstruct the raw multiplier (without the countermeasure) by doing: + + klass(*args, **kwargs) + + The countermeasure is simply in the `countermeasure` attribute and may be `None`. + + The error model is simply in the `error_model` attribute and may be `None`. If it is `None`, the MultIdent + is not suitable for error simulation and merely represents the description of a scalar multiplication + implementation we care about when reverse-engineering: the multiplier and the countermeasure, we do not + really care about the error model, yet need it when simulating. + """ klass: Type[ScalarMultiplier] args: list[Any] kwargs: dict[str, Any] @@ -142,6 +175,7 @@ class MultIdent: @cached_property def partial(self): + """Get the callable that constructs the scalar multiplier (with countermeasure if any).""" func = partial(self.klass, *self.args, **self.kwargs) if self.countermeasure is None: return func @@ -157,11 +191,13 @@ class MultIdent: return lambda *args, **kwargs: BrumleyTuveri(func(*args, **kwargs)) def with_countermeasure(self, countermeasure: str | None): + """Return a new MultIdent with a given countermeasure.""" if countermeasure not in (None, "gsr", "additive", "multiplicative", "euclidean", "bt"): raise ValueError(f"Unknown countermeasure: {countermeasure}") return MultIdent(self.klass, *self.args, **self.kwargs, countermeasure=countermeasure) def with_error_model(self, error_model: ErrorModel | None): + """Return a new MultIdent with a given error model.""" if not (isinstance(error_model, ErrorModel) or error_model is None): raise ValueError("Unknown error model.") return MultIdent(self.klass, *self.args, **self.kwargs, countermeasure=self.countermeasure, error_model=error_model) @@ -197,6 +233,12 @@ class MultIdent: @dataclass class MultResults: + """ + A MultResults instance represents many simulated scalar multiplciation computations, which were tracked + using a `MultipleContext`. Generally, these would be for one MultIdent only, but that should be handled + separately, for example in a dict[MultIdent, MultResults]. The `samples` describe how many computations + are contained and must correspond to the length of the `multiplications` list. + """ multiplications: list[tuple[MultipleContext, Point]] samples: int duration: Optional[float] = None @@ -224,6 +266,14 @@ class MultResults: @dataclass class ProbMap: + """ + A ProbMap is a mapping from integers (base point order q) to floats (error probability for some scalar + multiplication implementation, i.e. MultIdent). The probability map is constructed for a given set of + `divisors` (the base point orders q). Probability maps can be narrowed or merged. A narrowing restricts + the probability map to a smaller set of `divisors`. A merging takes another probability map using the + same divisor set and updates the probabilities to a weighted average of the two probability maps + (the weight is the number of samples). + """ probs: dict[int, float] divisors_hash: bytes samples: int @@ -250,6 +300,7 @@ class ProbMap: return self.probs.items() def narrow(self, divisors: set[int]): + """Narrow the probability map to the new set of divisors (must be a subset of the current set).""" divisors_hash = hashlib.blake2b(str(sorted(divisors)).encode(), digest_size=8).digest() if self.divisors_hash == divisors_hash: # Already narrow. @@ -259,6 +310,7 @@ class ProbMap: self.divisors_hash = divisors_hash def merge(self, other: "ProbMap") -> None: + """Merge the `other` probability map into this one (must share the divisor set).""" if self.divisors_hash != other.divisors_hash: raise ValueError("Merging can only work on probmaps created for same divisors.") new_keys = set(self.keys()).union(other.keys()) @@ -270,13 +322,6 @@ class ProbMap: self.probs = result self.samples += other.samples - def enrich(self, other: "ProbMap") -> None: - if self.samples != other.samples: - raise ValueError("Enriching can only work on equal amount of samples (same run, different divisors).") - if self.divisors_hash == other.divisors_hash or set(self.keys()) & set(other.keys()): - raise ValueError("Enriching can only work on distinct divisors.") - self.probs.update(other.probs) - # All dbl-and-add multipliers from https://github.com/J08nY/pyecsca/blob/master/pyecsca/ec/mult window_mults = [ @@ -355,13 +400,21 @@ other_mults = [ MultIdent(SimpleLadderMultiplier, complete=False) ] +# We can enumerate all mults and countermeasures here. all_mults = window_mults + naf_mults + binary_mults + other_mults + comb_mults all_mults_with_ctr = [mult.with_countermeasure(ctr) for mult in all_mults for ctr in (None, "gsr", "additive", "multiplicative", "euclidean", "bt")] +def powers_of(k, max_power=20): + """Take all powers of `k` up to `max_power`.""" + return [k**i for i in range(1, max_power)] +def prod_combine(one, other): + """Multiply all pairs of elements from `one` and `other`.""" + return [a * b for a, b in itertools.product(one, other)] +# We have several sets of divisors, inspired by various "interesting" multiples the multipliers may compute. small_primes = [3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97, 101, 103, 107, 109, 113, 127, 131, 137, 139, 149, 151, 157, 163, 167, 173, 179, 181, 191, 193, 197, 199] medium_primes = [211, 223, 227, 229, 233, 239, 241, 251, 257, 263, 269, 271, 277, 281, 283, 293, 307, 311, 313, 317, 331, 337, 347, 349, 353, 359, 367, 373, 379, 383, 389, 397] large_primes = [401, 409, 419, 421, 431, 433, 439, 443, 449, 457, 461, 463, 467, 479, 487, 491, 499, 503, 509, 521, 523, 541, 547, 557, 563, 569, 571, 577, 587, 593, 599, 601, 607, 613, 617, 619, 631, 641, 643, 647, 653, 659, 661, 673, 677, 683, 691, 701, 709, 719, 727, 733, 739, 743, 751, 757, 761, 769, 773, 787, 797, 809, 811, 821, 823, 827, 829, 839, 853, 857, 859, 863, 877, 881, 883, 887, 907, 911, 919, 929, 937, 941, 947, 953, 967, 971, 977, 983, 991, 997] @@ -390,4 +443,5 @@ divisor_map["all"] = list(sorted(set().union(*[v for v in divisor_map.values()]) def conf_interval(p: float, samples: int, alpha: float = 0.05) -> tuple[float, float]: + """Compute a confidence interval for a Binomial distribution.""" return proportion_confint(round(p*samples), samples, alpha, method="wilson") |
