| author | J08nY | 2025-07-31 13:42:59 +0200 |
|---|---|---|
| committer | J08nY | 2025-07-31 13:42:59 +0200 |
| commit | 8b4ec7a64a4cd62e3d70a70fdcaeebbf91e1cb73 (patch) | |
| tree | 8a9e511cafe12b77b12560e48c91ef78bd16cbf8 | |
| parent | cf38aed05f3dc15bd1ed375859e67526f6fcb079 (diff) | |
| download | ECTester-8b4ec7a64a4cd62e3d70a70fdcaeebbf91e1cb73.tar.gz ECTester-8b4ec7a64a4cd62e3d70a70fdcaeebbf91e1cb73.tar.zst ECTester-8b4ec7a64a4cd62e3d70a70fdcaeebbf91e1cb73.zip | |
| -rw-r--r-- | analysis/scalarmults/common.py | 122 |
| -rw-r--r-- | analysis/scalarmults/simulate.ipynb | 270 |
2 files changed, 290 insertions, 102 deletions
```diff
diff --git a/analysis/scalarmults/common.py b/analysis/scalarmults/common.py
index 2577c8c..eda6f5f 100644
--- a/analysis/scalarmults/common.py
+++ b/analysis/scalarmults/common.py
@@ -1,16 +1,101 @@
 import itertools
 from datetime import timedelta
+from enum import Enum
+from operator import itemgetter
 from dataclasses import dataclass
 from functools import partial, cached_property, total_ordering
-from typing import Any, Optional, Type
+from typing import Any, Optional, Type, Union, Literal
 
 from statsmodels.stats.proportion import proportion_confint
 
+from pyecsca.sca.re.rpa import MultipleContext
 from pyecsca.ec.mult import *
+from pyecsca.ec.point import Point
 from pyecsca.ec.countermeasures import GroupScalarRandomization, AdditiveSplitting, MultiplicativeSplitting, EuclideanSplitting, BrumleyTuveri
 
 
+def check_equal_multiples(k, l, q):
+    return (k % q) == (l % q)
+
+
+def check_divides(k, l, q):
+    return (k % q == 0) or (l % q == 0)
+
+
+def check_half_add(k, l, q):
+    return (q % 2 == 0) and ((k+l) % (q//2)) == 0
+
+
+def check_affine(k, q):
+    return k % q == 0
+
+
+def check_any(*checks, q=None):
+    def check_func(k, l):
+        for check in checks:
+            if check(k, l, q):
+                return True
+        return False
+    return check_func
+
+
+checks_add = {
+    "equal_multiples": check_equal_multiples,
+    "divides": check_divides,
+    "half_add": check_half_add
+}
+
+checks_affine = {
+    "affine": check_affine
+}
+
+
+@dataclass(frozen=True)
+@total_ordering
+class ErrorModel:
+    checks: set[str]
+    check_condition: Union[Literal["all"], Literal["necessary"]]
+    precomp_to_affine: bool
+
+    def __init__(self, checks: set[str], check_condition: Union[Literal["all"], Literal["necessary"]], precomp_to_affine: bool):
+        for check in checks:
+            if check not in checks_add:
+                raise ValueError(f"Unknown check: {check}")
+        checks = set(checks)
+        checks.add("affine")  # always done in our model
+        object.__setattr__(self, "checks", checks)
+        if check_condition not in ("all", "necessary"):
+            raise ValueError("Wrong check_condition")
+        object.__setattr__(self, "check_condition", check_condition)
+        object.__setattr__(self, "precomp_to_affine", precomp_to_affine)
+
+    def check_add(self, q):
+        if self.checks == {"affine"}:
+            return lambda k, l: False
+        return check_any(*map(lambda name: checks_add[name], filter(lambda check: check in checks_add, self.checks)), q=q)
+
+    def check_affine(self, q):
+        return partial(check_affine, q=q)
+
+    def __lt__(self, other):
+        if not isinstance(other, ErrorModel):
+            return NotImplemented
+        return str(self) < str(other)
+
+    def __str__(self):
+        cs = []
+        if "equal_multiples" in self.checks:
+            cs.append("em")
+        if "divides" in self.checks:
+            cs.append("d")
+        if "half_add" in self.checks:
+            cs.append("ha")
+        if "affine" in self.checks:
+            cs.append("a")
+        precomp = "+pre" if self.precomp_to_affine else ""
+        return f"({','.join(cs)}+{self.check_condition}{precomp})"
+
 
 @dataclass(frozen=True)
 @total_ordering
@@ -19,6 +104,7 @@ class MultIdent:
     args: list[Any]
     kwargs: dict[str, Any]
     countermeasure: Optional[str] = None
+    error_model: Optional[ErrorModel] = None
 
     def __init__(self, klass: Type[ScalarMultiplier], *args, **kwargs):
         object.__setattr__(self, "klass", klass)
@@ -26,6 +112,9 @@ class MultIdent:
         if kwargs is not None and "countermeasure" in kwargs:
             object.__setattr__(self, "countermeasure", kwargs["countermeasure"])
             del kwargs["countermeasure"]
+        if kwargs is not None and "error_model" in kwargs:
+            object.__setattr__(self, "error_model", kwargs["error_model"])
+            del kwargs["error_model"]
         object.__setattr__(self, "kwargs", kwargs if kwargs is not None else {})
 
     @cached_property
@@ -49,6 +138,11 @@ class MultIdent:
             raise ValueError(f"Unknown countermeasure: {countermeasure}")
         return MultIdent(self.klass, *self.args, **self.kwargs, countermeasure=countermeasure)
 
+    def with_error_model(self, error_model: ErrorModel | None):
+        if not (isinstance(error_model, ErrorModel) or error_model is None):
+            raise ValueError("Unknown error model.")
+        return MultIdent(self.klass, *self.args, **self.kwargs, error_model=error_model)
+
     def __str__(self):
         name = self.klass.__name__.replace("Multiplier", "")
         args = ("_" + ",".join(list(map(str, self.args)))) if self.args else ""
@@ -57,7 +151,8 @@ class MultIdent:
                  "width": "w"}
         kwargs = ("_" + ",".join(f"{kwmap.get(k, k)}:{v.name if isinstance(v, Enum) else str(v)}" for k,v in self.kwargs.items())) if self.kwargs else ""
         countermeasure = f"+{self.countermeasure}" if self.countermeasure is not None else ""
-        return f"{name}{args}{kwargs}{countermeasure}"
+        error_model = f"+{self.error_model}" if self.error_model is not None else ""
+        return f"{name}{args}{kwargs}{countermeasure}{error_model}"
 
     def __lt__(self, other):
         if not isinstance(other, MultIdent):
@@ -68,15 +163,14 @@ class MultIdent:
         return str(self)
 
     def __hash__(self):
-        return hash((self.klass, self.countermeasure, tuple(self.args), tuple(self.kwargs.keys()), tuple(self.kwargs.values())))
+        return hash((self.klass, self.countermeasure, self.error_model, tuple(self.args), tuple(self.kwargs.keys()), tuple(self.kwargs.values())))
 
 
 @dataclass
 class MultResults:
-    multiplications: list[set[int]]
+    multiplications: list[tuple[MultipleContext, Point]]
     samples: int
     duration: Optional[float] = None
-    kind: Optional[str] = None
 
     def merge(self, other: "MultResults"):
         self.multiplications.extend(other.multiplications)
@@ -93,8 +187,7 @@ class MultResults:
 
     def __str__(self):
         duration = timedelta(seconds=int(self.duration)) if self.duration is not None else ""
-        kind = self.kind if self.kind is not None else ""
-        return f"MultResults({self.samples},{duration},{kind})"
+        return f"MultResults({self.samples},{duration})"
 
     def __repr__(self):
         return str(self)
@@ -103,8 +196,8 @@ class MultResults:
 @dataclass
 class ProbMap:
     probs: dict[int, float]
+    divisors_hash: bytes
     samples: int
-    kind: Optional[str] = None
 
     def __len__(self):
         return len(self.probs)
@@ -113,7 +206,7 @@
         yield from self.probs
 
     def __getitem__(self, i):
-        return self.probs[i]
+        return self.probs[i] if i in self.probs else 0.0
 
     def keys(self):
         return self.probs.keys()
@@ -128,8 +221,7 @@ class ProbMap:
         self.probs = {k:v for k, v in self.probs.items() if k in divisors}
 
     def merge(self, other: "ProbMap") -> None:
-        if self.kind != other.kind:
-            raise ValueError("Merging ProbMaps of different kinds leads to unexpected results.")
+        # TODO: This may not be correct now that ProbMaps may not store zero probability items
        new_keys = set(self.keys()).union(other.keys())
         result = {}
         for key in new_keys:
@@ -145,8 +237,7 @@ class ProbMap:
     def enrich(self, other: "ProbMap") -> None:
         if self.samples != other.samples:
             raise ValueError("Enriching can only work on equal amount of samples (same run, different divisors)")
-        if self.kind != other.kind:
-            raise ValueError("Enriching ProbMaps of different kinds leads to unexpected results.")
+        # TODO: Check distinct divisors.
         self.probs.update(other.probs)
 
@@ -237,6 +328,11 @@ def powers_of(k, max_power=20):
 def prod_combine(one, other):
     return [a * b for a, b in itertools.product(one, other)]
 
+def powerset(iterable):
+    s = list(iterable)
+    return map(set, itertools.chain.from_iterable(itertools.combinations(s, r) for r in range(len(s)+1)))
+
+
 small_primes = [3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97, 101, 103, 107, 109, 113, 127, 131, 137, 139, 149, 151, 157, 163, 167, 173, 179, 181, 191, 193, 197, 199]
 medium_primes = [211, 223, 227, 229, 233, 239, 241, 251, 257, 263, 269, 271, 277, 281, 283, 293, 307, 311, 313, 317, 331, 337, 347, 349, 353, 359, 367, 373, 379, 383, 389, 397]
 large_primes = [401, 409, 419, 421, 431, 433, 439, 443, 449, 457, 461, 463, 467, 479, 487, 491, 499, 503, 509, 521, 523, 541, 547, 557, 563, 569, 571, 577, 587, 593, 599, 601, 607, 613, 617, 619, 631, 641, 643, 647, 653, 659, 661, 673, 677, 683, 691, 701, 709, 719, 727, 733, 739, 743, 751, 757, 761, 769, 773, 787, 797, 809, 811, 821, 823, 827, 829, 839, 853, 857, 859, 863, 877, 881, 883, 887, 907, 911, 919, 929, 937, 941, 947, 953, 967, 971, 977, 983, 991, 997]
diff --git a/analysis/scalarmults/simulate.ipynb b/analysis/scalarmults/simulate.ipynb
index 93d016a..b5b57bb 100644
--- a/analysis/scalarmults/simulate.ipynb
+++ b/analysis/scalarmults/simulate.ipynb
@@ -15,16 +15,25 @@
   "metadata": {},
   "outputs": [],
   "source": [
-   "import pickle\n",
    "import itertools\n",
    "import glob\n",
+   "import pickle\n",
+   "import random\n",
+   "import re\n",
+   "import hashlib\n",
+   "\n",
+   "import warnings\n",
+   "warnings.filterwarnings(\n",
+   "    \"ignore\",\n",
+   "    message=\"pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html.\",\n",
+   "    category=UserWarning\n",
+   ")\n",
    "\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "from collections import Counter\n",
-   "\n",
    "from pathlib import Path\n",
    "from random import randint, randbytes\n",
    "from typing import Type, Any\n",
@@ -33,7 +42,9 @@
    "\n",
    "from pyecsca.ec.params import DomainParameters, get_params\n",
    "from pyecsca.ec.mult import *\n",
-   "from pyecsca.sca.re.rpa import multiples_computed\n",
+   "from pyecsca.ec.mod import mod\n",
+   "from pyecsca.sca.re.rpa import multiple_graph\n",
+   "from pyecsca.sca.re.epa import graph_to_check_inputs, evaluate_checks\n",
    "from pyecsca.misc.utils import TaskExecutor\n",
    "\n",
    "from common import *"
  ]
 },
@@ -50,78 +61,38 @@
 {
  "cell_type": "code",
  "execution_count": null,
- "id": "a660e3ac-401b-47a0-92de-55afe63c420a",
+ "id": "3463a7bd-34d8-458b-8ceb-dddf99de21dc",
  "metadata": {},
  "outputs": [],
  "source": [
-  "print(len(all_mults))"
+  "def silence():\n",
+  "    import warnings\n",
+  "    warnings.filterwarnings(\n",
+  "        \"ignore\",\n",
+  "        message=\"pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html.\",\n",
+  "        category=UserWarning\n",
+  "    )\n",
+  "silence()"
  ]
 },
 {
  "cell_type": "code",
  "execution_count": null,
- "id": "a95b27fc-96a9-41b5-9972-dc8386ed386d",
+ "id": "a660e3ac-401b-47a0-92de-55afe63c420a",
  "metadata": {},
  "outputs": [],
  "source": [
-  "print(len(all_mults_with_ctr))"
+  "print(len(all_mults))"
  ]
 },
 {
  "cell_type": "code",
  "execution_count": null,
- "id": "07bc266d-35eb-4f6d-bdba-e9f6f66827f1",
+ "id": "a95b27fc-96a9-41b5-9972-dc8386ed386d",
  "metadata": {},
  "outputs": [],
  "source": [
-  "# Needs imports on the inside to be spawn enabled to save memory.\n",
-  "\n",
-  "def get_general_multiples(bits: int, samples: int = 1000) -> MultResults:\n",
-  "    from random import randint\n",
-  "    results = []\n",
-  "    for _ in range(samples):\n",
-  "        big_scalar = randint(1, 2**bits)\n",
-  "        results.append({big_scalar})\n",
-  "    return MultResults(results, samples)\n",
-  "\n",
-  "def get_general_n_multiples(bits: int, n: int, samples: int = 1000) -> MultResults:\n",
-  "    from random import randint\n",
-  "    results = []\n",
-  "    for _ in range(samples):\n",
-  "        smult = set()\n",
-  "        for i in range(n):\n",
-  "            b = randint(1,256)\n",
-  "            smult.add(randint(2**b,2**(b+1)))\n",
-  "        results.append(smult)\n",
-  "    return MultResults(results, samples)\n",
-  "\n",
-  "def get_small_scalar_multiples(mult: MultIdent,\n",
-  "                               params: DomainParameters,\n",
-  "                               bits: int,\n",
-  "                               samples: int = 100,\n",
-  "                               use_init: bool = True,\n",
-  "                               use_multiply: bool = True,\n",
-  "                               seed: bytes | None = None,\n",
-  "                               kind: str = \"precomp+necessary\") -> MultResults:\n",
-  "    from pyecsca.sca.re.rpa import multiples_computed\n",
-  "    import random\n",
-  "    \n",
-  "    results = []\n",
-  "    if seed is not None:\n",
-  "        random.seed(seed)\n",
-  "\n",
-  "    # If no countermeasure is used, we have fully random scalars.\n",
-  "    # Otherwise, fix one per chunk.\n",
-  "    if mult.countermeasure is None:\n",
-  "        scalars = [random.randint(1, 2**bits) for _ in range(samples)]\n",
-  "    else:\n",
-  "        one = random.randint(1, 2**bits)\n",
-  "        scalars = [one for _ in range(samples)]\n",
-  "\n",
-  "    for scalar in scalars:\n",
-  "        # Use a list for less memory usage.\n",
-  "        results.append(list(multiples_computed(scalar, params, mult.klass, mult.partial, use_init, use_multiply, kind=kind)))\n",
-  "    return MultResults(results, samples)"
+  "print(len(all_mults_with_ctr))"
  ]
 },
 {
@@ -152,6 +123,64 @@
  ]
 },
 {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "07bc266d-35eb-4f6d-bdba-e9f6f66827f1",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+  "def simulate_multiples(mult: MultIdent,\n",
+  "                       params: DomainParameters,\n",
+  "                       bits: int,\n",
+  "                       samples: int = 100,\n",
+  "                       use_init: bool = True,\n",
+  "                       use_multiply: bool = True,\n",
+  "                       seed: bytes | None = None) -> MultResults:\n",
+  "    results = []\n",
+  "    if seed is not None:\n",
+  "        random.seed(seed)\n",
+  "\n",
+  "    # If no countermeasure is used, we have fully random scalars.\n",
+  "    # Otherwise, fix one per chunk.\n",
+  "    if mult.countermeasure is None:\n",
+  "        scalars = [random.randint(1, 2**bits) for _ in range(samples)]\n",
+  "    else:\n",
+  "        one = random.randint(1, 2**bits)\n",
+  "        scalars = [one for _ in range(samples)]\n",
+  "\n",
+  "    for scalar in scalars:\n",
+  "        results.append(multiple_graph(scalar, params, mult.klass, mult.partial, use_init, use_multiply))\n",
+  "    return MultResults(results, samples)"
+ ]
+},
+{
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "64799c16-8113-4eff-81de-6a3e547eb5c5",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+  "def evaluate_multiples(mult: MultIdent, res: MultResults, divisors: set[int]):\n",
+  "    errors = {divisor: 0 for divisor in divisors}\n",
+  "    samples = len(res)\n",
+  "    divisors_hash = hashlib.blake2b(str(sorted(divisors)).encode(), digest_size=8).digest()\n",
+  "    for ctx, out in res:\n",
+  "        check_inputs = graph_to_check_inputs(ctx, out,\n",
+  "                                             check_condition=mult.error_model.check_condition,\n",
+  "                                             precomp_to_affine=mult.error_model.precomp_to_affine)\n",
+  "        for q in divisors:\n",
+  "            error = evaluate_checks(check_funcs={\"add\": mult.error_model.check_add(q), \"affine\": mult.error_model.check_affine(q)},\n",
+  "                                    check_inputs=check_inputs)\n",
+  "            errors[q] += error\n",
+  "    # Make probmaps smaller. Do not store zero probabilities.\n",
+  "    probs = {}\n",
+  "    for q, error in errors.items():\n",
+  "        if error != 0:\n",
+  "            probs[q] = error / samples\n",
+  "    return ProbMap(probs, divisors_hash, samples)"
+ ]
+},
+{
  "cell_type": "markdown",
  "id": "3aaf712e-5b97-4390-8dd4-e1db1dfe36a2",
  "metadata": {},
  "source": [
@@ -164,50 +193,113 @@
  "cell_type": "code",
  "execution_count": null,
  "id": "84359084-4116-436c-92cd-d43fdfeca842",
- "metadata": {},
+ "metadata": {
+  "scrolled": true
+ },
  "outputs": [],
  "source": [
-  "multiples_mults = {}\n",
   "chunk_id = randbytes(4).hex()\n",
-  "with TaskExecutor(max_workers=num_workers, mp_context=spawn_context) as pool, enable_spawn(get_small_scalar_multiples) as target:\n",
+  "with TaskExecutor(max_workers=num_workers, initializer=silence) as pool:\n",
   "    for mult in selected_mults:\n",
   "        for countermeasure in (None, \"gsr\", \"additive\", \"multiplicative\", \"euclidean\", \"bt\"):\n",
-  "            mwc = mult.with_countermeasure(countermeasure)\n",
-  "            pool.submit_task(mwc,\n",
-  "                             target,\n",
-  "                             mwc, params, bits, samples, seed=chunk_id, kind=kind, use_init=use_init, use_multiply=use_multiply)\n",
-  "    for mult, future in tqdm(pool.as_completed(), desc=\"Computing small scalar distributions.\", total=len(pool.tasks)):\n",
-  "        print(f\"Got {mult}.\")\n",
-  "        if error := future.exception():\n",
-  "            print(\"Error!\", error)\n",
-  "            continue\n",
-  "        res = future.result()\n",
-  "        if mult not in multiples_mults:\n",
-  "            multiples_mults[mult] = res\n",
-  "        else:\n",
-  "            # Accumulate\n",
-  "            multiples_mults[mult].merge(res)\n",
-  "\n",
-  "    # Handle the enable_spawn trick that messes up class modules.\n",
-  "    for k, v in multiples_mults.items():\n",
-  "        v.__class__ = MultResults\n",
-  "        v.__module__ = \"common\"\n",
-  "\n",
-  "# Save\n",
-  "with open(f\"multiples_{bits}_{'init' if use_init else 'noinit'}_{'mult' if use_multiply else 'nomult'}_chunk{chunk_id}.pickle\",\"wb\") as h:\n",
-  "    for mult, res in multiples_mults.items():\n",
-  "        pickle.dump((mult, res), h)"
+  "            full = mult.with_countermeasure(countermeasure)\n",
+  "            pool.submit_task(full,\n",
+  "                             simulate_multiples,\n",
+  "                             full, params, bits, samples, seed=chunk_id, use_init=use_init, use_multiply=use_multiply)\n",
+  "    with open(f\"multiples_{bits}_{'init' if use_init else 'noinit'}_{'mult' if use_multiply else 'nomult'}_chunk{chunk_id}.pickle\",\"wb\") as h:\n",
+  "        for mult, future in tqdm(pool.as_completed(), desc=\"Computing multiple graphs.\", total=len(pool.tasks)):\n",
+  "            print(f\"Got {mult}.\")\n",
+  "            if error := future.exception():\n",
+  "                print(\"Error!\", error)\n",
+  "                continue\n",
+  "            res = future.result()\n",
+  "            pickle.dump((mult, res), h)"
  ]
 },
 {
  "cell_type": "markdown",
- "id": "b4471a1d-fdc3-4be7-bd61-5ddd22180b41",
+ "id": "44120f28-ae4a-42e3-befb-ebc487d51f9e",
  "metadata": {},
  "source": [
-  "### Load\n",
-  "**Beware**, the following load with try to load all chunks into memory, that will be very large.\n",
+  "## Process"
+ ]
+},
+{
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "fbab8333-b8f1-4890-b38a-7bb34f5ffb02",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+  "with TaskExecutor(max_workers=num_workers, initializer=silence) as pool:\n",
+  "    for fname in glob.glob(f\"multiples_{bits}_{'init' if use_init else 'noinit'}_{'mult' if use_multiply else 'nomult'}_chunk*.pickle\"):\n",
+  "        match = re.match(\"multiples_[0-9]+_(?P<init>(?:no)?init)_(?P<mult>(?:no)?mult)_chunk(?P<id>[0-9a-f]+).pickle\", fname)\n",
+  "        use_init = match.group(\"init\") == \"init\"\n",
+  "        use_multiply = match.group(\"mult\") == \"mult\"\n",
+  "        chunk_id = match.group(\"id\")\n",
+  "        multiples_mults = {}\n",
+  "        with open(fname, \"rb\") as f:\n",
+  "            bar = tqdm(total=len(all_mults_with_ctr), desc=f\"Loading chunk {chunk_id}.\")\n",
+  "            while True:\n",
+  "                try:\n",
+  "                    mult, vals = pickle.load(f)\n",
+  "                    bar.update(1)\n",
+  "                    if mult not in multiples_mults:\n",
+  "                        multiples_mults[mult] = vals\n",
+  "                    else:\n",
+  "                        multiples_mults[mult].merge(vals)\n",
+  "                except EOFError:\n",
+  "                    break\n",
+  "        for mult, res in multiples_mults.items():\n",
+  "            for checks in powerset(checks_add):\n",
+  "                for precomp_to_affine in (True, False):\n",
+  "                    for check_condition in (\"all\", \"necessary\"):\n",
+  "                        error_model = ErrorModel(checks, check_condition=check_condition, precomp_to_affine=precomp_to_affine)\n",
+  "                        full = mult.with_error_model(error_model)\n",
+  "                        pool.submit_task(full,\n",
+  "                                         evaluate_multiples,\n",
+  "                                         full, res, divisor_map[\"all\"])\n",
+  "        fname = f\"probs_{use_init}_{use_multiply}_chunk{chunk_id}.pickle\"\n",
+  "        with open(fname, \"wb\") as f:\n",
+  "            for full, future in tqdm(pool.as_completed(), desc=\"Computing errors.\", total=len(pool.tasks)):\n",
+  "                print(f\"Got {full}.\")\n",
+  "                if error := future.exception():\n",
+  "                    print(\"Error!\", error)\n",
+  "                    continue\n",
+  "                res = future.result()\n",
+  "                pickle.dump((full, res), f)"
+ ]
+},
+{
+ "cell_type": "markdown",
+ "id": "228922dc-67bf-481a-9f08-4084695e2059",
+ "metadata": {},
+ "source": [
+  "## Misc"
+ ]
+},
+{
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "8cdaad25-80d2-4574-9cfb-9d93e55e90d6",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+  "from pyinstrument import Profiler as PyProfiler\n",
+  "mult = next(iter(multiples_mults))\n",
+  "res = multiples_mults[mult]\n",
+  "\n",
   "\n",
-  "You probably dont want to run this."
+  "for checks in powerset(checks_add):\n",
+  "    for precomp_to_affine in (True, False):\n",
+  "        for check_condition in (\"all\", \"necessary\"):\n",
+  "            error_model = ErrorModel(checks, check_condition=check_condition, precomp_to_affine=precomp_to_affine)\n",
+  "            full = mult.with_error_model(error_model)\n",
+  "            print(full)\n",
+  "            #with PyProfiler() as prof:\n",
+  "            probmap = evaluate_multiples(full, res, divisor_map[\"all\"])\n",
+  "            #print(prof.output_text(unicode=True, color=True))\n",
+  "            #print(probmap)"
  ]
 },
@@ -234,7 +326,7 @@
 {
  "cell_type": "code",
  "execution_count": null,
- "id": "11b447f2-71ab-417e-a856-1724788cfc91",
+ "id": "daba5215-fef8-4c8a-8d7d-1af49edffa7b",
  "metadata": {},
  "outputs": [],
  "source": []
```
