diff options
| author | Ján Jančár | 2023-02-13 00:30:09 +0100 |
|---|---|---|
| committer | GitHub | 2023-02-13 00:30:09 +0100 |
| commit | 14c04894992e3def1f0be08849cce8c8014cd96d (patch) | |
| tree | cc16095f01eee9ac84a551800c1de59ce74609a5 | |
| parent | d4ebdf7e7ed1867b49f353ade6c130b5dcbc3ed1 (diff) | |
| parent | 82e5e4e75926eb9ca581594d0d4deade0c7ef703 (diff) | |
| download | pyecsca-14c04894992e3def1f0be08849cce8c8014cd96d.tar.gz pyecsca-14c04894992e3def1f0be08849cce8c8014cd96d.tar.zst pyecsca-14c04894992e3def1f0be08849cce8c8014cd96d.zip | |
Merge pull request #28 from J08nY/feature/perf
Improve performance of simulation
| -rw-r--r-- | pyecsca/ec/context.py | 70 | ||||
| -rw-r--r-- | pyecsca/ec/formula.py | 25 | ||||
| -rw-r--r-- | pyecsca/ec/mod.py | 63 | ||||
| -rwxr-xr-x | test/ec/perf_formula.py | 6 | ||||
| -rw-r--r-- | test/ec/test_context.py | 15 | ||||
| -rw-r--r-- | test/ec/test_mod.py | 3 | ||||
| -rw-r--r-- | test/utils.py | 2 |
7 files changed, 85 insertions, 99 deletions
diff --git a/pyecsca/ec/context.py b/pyecsca/ec/context.py index a0fd3c2..49c886d 100644 --- a/pyecsca/ec/context.py +++ b/pyecsca/ec/context.py @@ -14,7 +14,6 @@ A :py:class:`NullContext` does not trace any actions and is the default context. """ from abc import abstractmethod, ABC from collections import OrderedDict -from contextvars import ContextVar, Token from copy import deepcopy from typing import List, Optional, ContextManager, Any, Tuple, Sequence @@ -31,12 +30,14 @@ class Action: self.inside = False def __enter__(self): - getcontext().enter_action(self) + if current is not None: + current.enter_action(self) self.inside = True return self def __exit__(self, exc_type, exc_val, exc_tb): - getcontext().exit_action(self) + if current is not None: + current.exit_action(self) self.inside = False @@ -64,10 +65,10 @@ class ResultAction(Action): def __exit__(self, exc_type, exc_val, exc_tb): if ( - not self._has_result - and exc_type is None - and exc_val is None - and exc_tb is None + not self._has_result + and exc_type is None + and exc_val is None + and exc_tb is None ): raise RuntimeError("Result unset on action exit") super().__exit__(exc_type, exc_val, exc_tb) @@ -166,17 +167,6 @@ class Context(ABC): @public -class NullContext(Context): - """Context that does not trace any actions.""" - - def enter_action(self, action: Action) -> None: - pass # Nothing to enter as no action is traced. - - def exit_action(self, action: Action) -> None: - pass # Nothing to exit as no action is traced. - - -@public class DefaultContext(Context): """Context that traces executions of actions in a tree.""" @@ -240,48 +230,22 @@ class PathContext(Context): ) -_actual_context: ContextVar[Context] = ContextVar( - "operational_context", default=NullContext() -) +current: Optional[Context] = None class _ContextManager: def __init__(self, new_context): self.new_context = deepcopy(new_context) - def __enter__(self) -> Context: - self.token = setcontext(self.new_context) - return self.new_context + def __enter__(self) -> Optional[Context]: + global current + self.old_context = current + current = self.new_context + return current def __exit__(self, t, v, tb): - resetcontext(self.token) - - -@public -def getcontext() -> Context: - """Get the current thread/task context.""" - return _actual_context.get() - - -@public -def setcontext(ctx: Context) -> Token: - """ - Set the current thread/task context. - - :param ctx: A context to set. - :return: A token to restore previous context. - """ - return _actual_context.set(ctx) - - -@public -def resetcontext(token: Token): - """ - Reset the context to a previous value. - - :param token: A token to restore. - """ - _actual_context.reset(token) + global current + current = self.old_context @public @@ -293,5 +257,5 @@ def local(ctx: Optional[Context] = None) -> ContextManager: :return: A context manager. """ if ctx is None: - ctx = getcontext() + ctx = current return _ContextManager(ctx) diff --git a/pyecsca/ec/formula.py b/pyecsca/ec/formula.py index 994709c..ec8f9c0 100644 --- a/pyecsca/ec/formula.py +++ b/pyecsca/ec/formula.py @@ -1,6 +1,8 @@ """Provides an abstract base class of a formula along with concrete instantiations.""" from abc import ABC, abstractmethod from ast import parse, Expression +from functools import cached_property + from astunparse import unparse from itertools import product from typing import List, Set, Any, ClassVar, MutableMapping, Tuple, Union, Dict @@ -9,7 +11,8 @@ from pkg_resources import resource_stream from public import public from sympy import sympify, FF, symbols, Poly, Rational -from .context import ResultAction, getcontext, NullContext +from .context import ResultAction +from . import context from .error import UnsatisfiedAssumptionError, raise_unsatisified_assumption from .mod import Mod, SymbolicMod from .op import CodeOp, OpType @@ -67,8 +70,6 @@ class FormulaAction(ResultAction): self.output_points = [] def add_operation(self, op: CodeOp, value: Mod): - if isinstance(getcontext(), NullContext): - return parents: List[Union[Mod, OpResult]] = [] for parent in {*op.variables, *op.parameters}: if parent in self.intermediates: @@ -79,8 +80,6 @@ class FormulaAction(ResultAction): li.append(OpResult(op.result, value, op.operator, *parents)) def add_result(self, point: Any, **outputs: Mod): - if isinstance(getcontext(), NullContext): - return for k in outputs: self.outputs[k] = self.intermediates[k][-1] self.output_points.append(point) @@ -117,6 +116,10 @@ class Formula(ABC): unified: bool """Whether the formula is specifies that it is unified.""" + @cached_property + def assumptions_str(self): + return [unparse(assumption)[1:-2] for assumption in self.assumptions] + def __validate_params(self, field, params): for key, value in params.items(): if not isinstance(value, Mod) or value.n != field: @@ -141,8 +144,7 @@ class Formula(ABC): # Validate assumptions and compute formula parameters. # TODO: Should this also validate coordinate assumptions and compute their parameters? is_symbolic = any(isinstance(x, SymbolicMod) for x in params.values()) - for assumption in self.assumptions: - assumption_string = unparse(assumption)[1:-2] + for assumption, assumption_string in zip(self.assumptions, self.assumptions_str): lhs, rhs = assumption_string.split(" == ") if lhs in params: # Handle an assumption check on value of input points. @@ -220,7 +222,8 @@ class Formula(ABC): self.__validate_params(field, params) self.__validate_points(field, points, params) - self.__validate_assumptions(field, params) + if self.assumptions: + self.__validate_assumptions(field, params) # Execute the actual formula. with FormulaAction(self, *points, **params) as action: for op in self.code: @@ -234,7 +237,8 @@ class Formula(ABC): ) if not isinstance(op_result, Mod): op_result = Mod(op_result, field) - action.add_operation(op, op_result) + if context.current is not None: + action.add_operation(op, op_result) params[op.result] = op_result result = [] # Go over the outputs and construct the resulting points. @@ -248,7 +252,8 @@ class Formula(ABC): full_resulting[full_variable] = params[full_variable] point = Point(self.coordinate_model, **resulting) - action.add_result(point, **full_resulting) + if context.current is not None: + action.add_result(point, **full_resulting) result.append(point) return action.exit(tuple(result)) diff --git a/pyecsca/ec/mod.py b/pyecsca/ec/mod.py index add7581..a43db53 100644 --- a/pyecsca/ec/mod.py +++ b/pyecsca/ec/mod.py @@ -10,7 +10,7 @@ dispatches to the implementation chosen by the runtime configuration of the libr import random import secrets from functools import wraps, lru_cache -from typing import Type, Dict, Any, Tuple +from typing import Type, Dict, Any, Tuple, Union from public import public from sympy import Expr, FF @@ -116,9 +116,8 @@ def _check(func): def method(self, other): if type(self) is not type(other): other = self.__class__(other, self.n) - else: - if self.n != other.n: - raise ValueError + elif self.n != other.n: + raise ValueError return func(self, other) return method @@ -147,6 +146,7 @@ class Mod: x: Any n: Any + __slots__ = ("x", "n") def __new__(cls, *args, **kwargs): if cls != Mod: @@ -263,6 +263,7 @@ class RawMod(Mod): x: int n: int + __slots__ = ("x", "n") def __new__(cls, *args, **kwargs): return object.__new__(cls) @@ -367,6 +368,7 @@ _mod_classes["python"] = RawMod @public class Undefined(Mod): """A special undefined element.""" + __slots__ = ("x", "n") def __new__(cls, *args, **kwargs): return object.__new__(cls) @@ -472,6 +474,7 @@ class SymbolicMod(Mod): x: Expr n: int + __slots__ = ("x", "n") def __new__(cls, *args, **kwargs): return object.__new__(cls) @@ -580,25 +583,30 @@ if has_gmp: x: gmpy2.mpz n: gmpy2.mpz + __slots__ = ("x", "n") def __new__(cls, *args, **kwargs): return object.__new__(cls) - def __init__(self, x: int, n: int): - self.x = gmpy2.mpz(x % n) - self.n = gmpy2.mpz(n) + def __init__(self, x: Union[int, gmpy2.mpz], n: Union[int, gmpy2.mpz], ensure: bool = True): + if ensure: + self.n = gmpy2.mpz(n) + self.x = gmpy2.mpz(x % self.n) + else: + self.n = n + self.x = x def inverse(self) -> "GMPMod": if self.x == 0: raise_non_invertible() if self.x == 1: - return GMPMod(1, self.n) + return GMPMod(gmpy2.mpz(1), self.n, ensure=False) try: res = gmpy2.invert(self.x, self.n) except ZeroDivisionError: raise_non_invertible() - res = 0 - return GMPMod(res, self.n) + res = gmpy2.mpz(0) + return GMPMod(res, self.n, ensure=False) def is_residue(self) -> bool: if not _is_prime(self.n): @@ -613,7 +621,7 @@ if has_gmp: if not _is_prime(self.n): raise NotImplementedError if self.x == 0: - return GMPMod(0, self.n) + return GMPMod(gmpy2.mpz(0), self.n, ensure=False) if not self.is_residue(): raise_non_residue() if self.n % 4 == 3: @@ -624,12 +632,12 @@ if has_gmp: q //= 2 s += 1 - z = 2 - while GMPMod(z, self.n).is_residue(): + z = gmpy2.mpz(2) + while GMPMod(z, self.n, ensure=False).is_residue(): z += 1 m = s - c = GMPMod(z, self.n) ** int(q) + c = GMPMod(z, self.n, ensure=False) ** int(q) t = self ** int(q) r_exp = (q + 1) // 2 r = self ** int(r_exp) @@ -639,17 +647,32 @@ if has_gmp: while not (t ** (2 ** i)) == 1: i += 1 two_exp = m - (i + 1) - b = c ** int(GMPMod(2, self.n) ** two_exp) - m = int(GMPMod(i, self.n)) + b = c ** int(GMPMod(gmpy2.mpz(2), self.n, ensure=False) ** two_exp) + m = int(GMPMod(gmpy2.mpz(i), self.n, ensure=False)) c = b ** 2 t *= c r *= b return r @_check + def __add__(self, other) -> "GMPMod": + return GMPMod((self.x + other.x) % self.n, self.n, ensure=False) + + @_check + def __sub__(self, other) -> "GMPMod": + return GMPMod((self.x - other.x) % self.n, self.n, ensure=False) + + def __neg__(self) -> "GMPMod": + return GMPMod(self.n - self.x, self.n, ensure=False) + + @_check + def __mul__(self, other) -> "GMPMod": + return GMPMod((self.x * other.x) % self.n, self.n, ensure=False) + + @_check def __divmod__(self, divisor) -> Tuple["GMPMod", "GMPMod"]: q, r = gmpy2.f_divmod(self.x, divisor.x) - return GMPMod(q, self.n), GMPMod(r, self.n) + return GMPMod(q, self.n, ensure=False), GMPMod(r, self.n, ensure=False) def __bytes__(self): return int(self.x).to_bytes((self.n.bit_length() + 7) // 8, byteorder="big") @@ -677,11 +700,11 @@ if has_gmp: if type(n) not in (int, gmpy2.mpz): raise TypeError if n == 0: - return GMPMod(1, self.n) + return GMPMod(gmpy2.mpz(1), self.n, ensure=False) if n < 0: return self.inverse() ** (-n) if n == 1: - return GMPMod(self.x, self.n) - return GMPMod(gmpy2.powmod(self.x, gmpy2.mpz(n), self.n), self.n) + return GMPMod(self.x, self.n, ensure=False) + return GMPMod(gmpy2.powmod(self.x, gmpy2.mpz(n), self.n), self.n, ensure=False) _mod_classes["gmp"] = GMPMod diff --git a/test/ec/perf_formula.py b/test/ec/perf_formula.py index b49daab..baa6347 100755 --- a/test/ec/perf_formula.py +++ b/test/ec/perf_formula.py @@ -31,7 +31,7 @@ def main(profiler, mod, operations, directory): add = coords.formulas["add-2016-rcb"] dbl = coords.formulas["dbl-2016-rcb"] click.echo( - f"Profiling {operations} {p256.curve.prime.bit_length()}-bit doubling formula executions..." + f"Profiling {operations} {p256.curve.prime.bit_length()}-bit doubling formula (dbl2016rcb) executions..." ) one_point = p256.generator with Profiler( @@ -40,7 +40,7 @@ def main(profiler, mod, operations, directory): for _ in range(operations): one_point = dbl(p256.curve.prime, one_point, **p256.curve.parameters)[0] click.echo( - f"Profiling {operations} {p256.curve.prime.bit_length()}-bit addition formula executions..." + f"Profiling {operations} {p256.curve.prime.bit_length()}-bit addition formula (add2016rcb) executions..." ) other_point = p256.generator with Profiler( @@ -54,7 +54,7 @@ def main(profiler, mod, operations, directory): ecoords = ed25519.curve.coordinate_model dblg = ecoords.formulas["mdbl-2008-hwcd"] click.echo( - f"Profiling {operations} {ed25519.curve.prime.bit_length()}-bit doubling formula executions (with assumption)..." + f"Profiling {operations} {ed25519.curve.prime.bit_length()}-bit doubling formula (mdbl2008hwcd) executions (with assumption)..." ) eone_point = ed25519.generator with Profiler( diff --git a/test/ec/test_context.py b/test/ec/test_context.py index 6691985..9cd74a3 100644 --- a/test/ec/test_context.py +++ b/test/ec/test_context.py @@ -3,12 +3,8 @@ from unittest import TestCase from pyecsca.ec.context import ( local, DefaultContext, - NullContext, - getcontext, - setcontext, - resetcontext, Tree, - PathContext, + PathContext ) from pyecsca.ec.key_generation import KeyGeneration from pyecsca.ec.params import get_params @@ -68,18 +64,14 @@ class ContextTests(TestCase): def test_null(self): with local() as ctx: self.mult.multiply(59) - self.assertIsInstance(ctx, NullContext) + self.assertIs(ctx, None) def test_default(self): - token = setcontext(DefaultContext()) - self.addCleanup(resetcontext, token) - with local(DefaultContext()) as ctx: result = self.mult.multiply(59) self.assertEqual(len(ctx.actions), 1) action = next(iter(ctx.actions.keys())) self.assertIsInstance(action, ScalarMultiplicationAction) - self.assertEqual(len(getcontext().actions), 0) self.assertEqual(result, action.result) def test_default_no_enter(self): @@ -100,6 +92,5 @@ class ContextTests(TestCase): self.mult.multiply(59) str(default) str(default.actions) - with local(NullContext()) as null: + with local(None): self.mult.multiply(59) - str(null) diff --git a/test/ec/test_mod.py b/test/ec/test_mod.py index 7802e95..62022b0 100644 --- a/test/ec/test_mod.py +++ b/test/ec/test_mod.py @@ -176,6 +176,9 @@ class ModTests(TestCase): "__hash__", "__abstractmethods__", "_abc_impl", + "__slots__", + "x", + "n" ): continue args = [5 for _ in range(meth.__code__.co_argcount - 1)] diff --git a/test/utils.py b/test/utils.py index d4893eb..e125813 100644 --- a/test/utils.py +++ b/test/utils.py @@ -62,7 +62,7 @@ class Profiler: if self._state != "out": raise ValueError if self._prof_type == "py": - print(self._prof.output_text(unicode=True, color=True)) + print(self._prof.output_text(unicode=True, color=True, show_all=True)) else: self._prof.print_stats("cumtime") |
