aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorJán Jančár2023-02-13 00:30:09 +0100
committerGitHub2023-02-13 00:30:09 +0100
commit14c04894992e3def1f0be08849cce8c8014cd96d (patch)
treecc16095f01eee9ac84a551800c1de59ce74609a5
parentd4ebdf7e7ed1867b49f353ade6c130b5dcbc3ed1 (diff)
parent82e5e4e75926eb9ca581594d0d4deade0c7ef703 (diff)
downloadpyecsca-14c04894992e3def1f0be08849cce8c8014cd96d.tar.gz
pyecsca-14c04894992e3def1f0be08849cce8c8014cd96d.tar.zst
pyecsca-14c04894992e3def1f0be08849cce8c8014cd96d.zip
Merge pull request #28 from J08nY/feature/perf
Improve performance of simulation
-rw-r--r--pyecsca/ec/context.py70
-rw-r--r--pyecsca/ec/formula.py25
-rw-r--r--pyecsca/ec/mod.py63
-rwxr-xr-xtest/ec/perf_formula.py6
-rw-r--r--test/ec/test_context.py15
-rw-r--r--test/ec/test_mod.py3
-rw-r--r--test/utils.py2
7 files changed, 85 insertions, 99 deletions
diff --git a/pyecsca/ec/context.py b/pyecsca/ec/context.py
index a0fd3c2..49c886d 100644
--- a/pyecsca/ec/context.py
+++ b/pyecsca/ec/context.py
@@ -14,7 +14,6 @@ A :py:class:`NullContext` does not trace any actions and is the default context.
"""
from abc import abstractmethod, ABC
from collections import OrderedDict
-from contextvars import ContextVar, Token
from copy import deepcopy
from typing import List, Optional, ContextManager, Any, Tuple, Sequence
@@ -31,12 +30,14 @@ class Action:
self.inside = False
def __enter__(self):
- getcontext().enter_action(self)
+ if current is not None:
+ current.enter_action(self)
self.inside = True
return self
def __exit__(self, exc_type, exc_val, exc_tb):
- getcontext().exit_action(self)
+ if current is not None:
+ current.exit_action(self)
self.inside = False
@@ -64,10 +65,10 @@ class ResultAction(Action):
def __exit__(self, exc_type, exc_val, exc_tb):
if (
- not self._has_result
- and exc_type is None
- and exc_val is None
- and exc_tb is None
+ not self._has_result
+ and exc_type is None
+ and exc_val is None
+ and exc_tb is None
):
raise RuntimeError("Result unset on action exit")
super().__exit__(exc_type, exc_val, exc_tb)
@@ -166,17 +167,6 @@ class Context(ABC):
@public
-class NullContext(Context):
- """Context that does not trace any actions."""
-
- def enter_action(self, action: Action) -> None:
- pass # Nothing to enter as no action is traced.
-
- def exit_action(self, action: Action) -> None:
- pass # Nothing to exit as no action is traced.
-
-
-@public
class DefaultContext(Context):
"""Context that traces executions of actions in a tree."""
@@ -240,48 +230,22 @@ class PathContext(Context):
)
-_actual_context: ContextVar[Context] = ContextVar(
- "operational_context", default=NullContext()
-)
+current: Optional[Context] = None
class _ContextManager:
def __init__(self, new_context):
self.new_context = deepcopy(new_context)
- def __enter__(self) -> Context:
- self.token = setcontext(self.new_context)
- return self.new_context
+ def __enter__(self) -> Optional[Context]:
+ global current
+ self.old_context = current
+ current = self.new_context
+ return current
def __exit__(self, t, v, tb):
- resetcontext(self.token)
-
-
-@public
-def getcontext() -> Context:
- """Get the current thread/task context."""
- return _actual_context.get()
-
-
-@public
-def setcontext(ctx: Context) -> Token:
- """
- Set the current thread/task context.
-
- :param ctx: A context to set.
- :return: A token to restore previous context.
- """
- return _actual_context.set(ctx)
-
-
-@public
-def resetcontext(token: Token):
- """
- Reset the context to a previous value.
-
- :param token: A token to restore.
- """
- _actual_context.reset(token)
+ global current
+ current = self.old_context
@public
@@ -293,5 +257,5 @@ def local(ctx: Optional[Context] = None) -> ContextManager:
:return: A context manager.
"""
if ctx is None:
- ctx = getcontext()
+ ctx = current
return _ContextManager(ctx)
diff --git a/pyecsca/ec/formula.py b/pyecsca/ec/formula.py
index 994709c..ec8f9c0 100644
--- a/pyecsca/ec/formula.py
+++ b/pyecsca/ec/formula.py
@@ -1,6 +1,8 @@
"""Provides an abstract base class of a formula along with concrete instantiations."""
from abc import ABC, abstractmethod
from ast import parse, Expression
+from functools import cached_property
+
from astunparse import unparse
from itertools import product
from typing import List, Set, Any, ClassVar, MutableMapping, Tuple, Union, Dict
@@ -9,7 +11,8 @@ from pkg_resources import resource_stream
from public import public
from sympy import sympify, FF, symbols, Poly, Rational
-from .context import ResultAction, getcontext, NullContext
+from .context import ResultAction
+from . import context
from .error import UnsatisfiedAssumptionError, raise_unsatisified_assumption
from .mod import Mod, SymbolicMod
from .op import CodeOp, OpType
@@ -67,8 +70,6 @@ class FormulaAction(ResultAction):
self.output_points = []
def add_operation(self, op: CodeOp, value: Mod):
- if isinstance(getcontext(), NullContext):
- return
parents: List[Union[Mod, OpResult]] = []
for parent in {*op.variables, *op.parameters}:
if parent in self.intermediates:
@@ -79,8 +80,6 @@ class FormulaAction(ResultAction):
li.append(OpResult(op.result, value, op.operator, *parents))
def add_result(self, point: Any, **outputs: Mod):
- if isinstance(getcontext(), NullContext):
- return
for k in outputs:
self.outputs[k] = self.intermediates[k][-1]
self.output_points.append(point)
@@ -117,6 +116,10 @@ class Formula(ABC):
unified: bool
"""Whether the formula is specifies that it is unified."""
+ @cached_property
+ def assumptions_str(self):
+ return [unparse(assumption)[1:-2] for assumption in self.assumptions]
+
def __validate_params(self, field, params):
for key, value in params.items():
if not isinstance(value, Mod) or value.n != field:
@@ -141,8 +144,7 @@ class Formula(ABC):
# Validate assumptions and compute formula parameters.
# TODO: Should this also validate coordinate assumptions and compute their parameters?
is_symbolic = any(isinstance(x, SymbolicMod) for x in params.values())
- for assumption in self.assumptions:
- assumption_string = unparse(assumption)[1:-2]
+ for assumption, assumption_string in zip(self.assumptions, self.assumptions_str):
lhs, rhs = assumption_string.split(" == ")
if lhs in params:
# Handle an assumption check on value of input points.
@@ -220,7 +222,8 @@ class Formula(ABC):
self.__validate_params(field, params)
self.__validate_points(field, points, params)
- self.__validate_assumptions(field, params)
+ if self.assumptions:
+ self.__validate_assumptions(field, params)
# Execute the actual formula.
with FormulaAction(self, *points, **params) as action:
for op in self.code:
@@ -234,7 +237,8 @@ class Formula(ABC):
)
if not isinstance(op_result, Mod):
op_result = Mod(op_result, field)
- action.add_operation(op, op_result)
+ if context.current is not None:
+ action.add_operation(op, op_result)
params[op.result] = op_result
result = []
# Go over the outputs and construct the resulting points.
@@ -248,7 +252,8 @@ class Formula(ABC):
full_resulting[full_variable] = params[full_variable]
point = Point(self.coordinate_model, **resulting)
- action.add_result(point, **full_resulting)
+ if context.current is not None:
+ action.add_result(point, **full_resulting)
result.append(point)
return action.exit(tuple(result))
diff --git a/pyecsca/ec/mod.py b/pyecsca/ec/mod.py
index add7581..a43db53 100644
--- a/pyecsca/ec/mod.py
+++ b/pyecsca/ec/mod.py
@@ -10,7 +10,7 @@ dispatches to the implementation chosen by the runtime configuration of the libr
import random
import secrets
from functools import wraps, lru_cache
-from typing import Type, Dict, Any, Tuple
+from typing import Type, Dict, Any, Tuple, Union
from public import public
from sympy import Expr, FF
@@ -116,9 +116,8 @@ def _check(func):
def method(self, other):
if type(self) is not type(other):
other = self.__class__(other, self.n)
- else:
- if self.n != other.n:
- raise ValueError
+ elif self.n != other.n:
+ raise ValueError
return func(self, other)
return method
@@ -147,6 +146,7 @@ class Mod:
x: Any
n: Any
+ __slots__ = ("x", "n")
def __new__(cls, *args, **kwargs):
if cls != Mod:
@@ -263,6 +263,7 @@ class RawMod(Mod):
x: int
n: int
+ __slots__ = ("x", "n")
def __new__(cls, *args, **kwargs):
return object.__new__(cls)
@@ -367,6 +368,7 @@ _mod_classes["python"] = RawMod
@public
class Undefined(Mod):
"""A special undefined element."""
+ __slots__ = ("x", "n")
def __new__(cls, *args, **kwargs):
return object.__new__(cls)
@@ -472,6 +474,7 @@ class SymbolicMod(Mod):
x: Expr
n: int
+ __slots__ = ("x", "n")
def __new__(cls, *args, **kwargs):
return object.__new__(cls)
@@ -580,25 +583,30 @@ if has_gmp:
x: gmpy2.mpz
n: gmpy2.mpz
+ __slots__ = ("x", "n")
def __new__(cls, *args, **kwargs):
return object.__new__(cls)
- def __init__(self, x: int, n: int):
- self.x = gmpy2.mpz(x % n)
- self.n = gmpy2.mpz(n)
+ def __init__(self, x: Union[int, gmpy2.mpz], n: Union[int, gmpy2.mpz], ensure: bool = True):
+ if ensure:
+ self.n = gmpy2.mpz(n)
+ self.x = gmpy2.mpz(x % self.n)
+ else:
+ self.n = n
+ self.x = x
def inverse(self) -> "GMPMod":
if self.x == 0:
raise_non_invertible()
if self.x == 1:
- return GMPMod(1, self.n)
+ return GMPMod(gmpy2.mpz(1), self.n, ensure=False)
try:
res = gmpy2.invert(self.x, self.n)
except ZeroDivisionError:
raise_non_invertible()
- res = 0
- return GMPMod(res, self.n)
+ res = gmpy2.mpz(0)
+ return GMPMod(res, self.n, ensure=False)
def is_residue(self) -> bool:
if not _is_prime(self.n):
@@ -613,7 +621,7 @@ if has_gmp:
if not _is_prime(self.n):
raise NotImplementedError
if self.x == 0:
- return GMPMod(0, self.n)
+ return GMPMod(gmpy2.mpz(0), self.n, ensure=False)
if not self.is_residue():
raise_non_residue()
if self.n % 4 == 3:
@@ -624,12 +632,12 @@ if has_gmp:
q //= 2
s += 1
- z = 2
- while GMPMod(z, self.n).is_residue():
+ z = gmpy2.mpz(2)
+ while GMPMod(z, self.n, ensure=False).is_residue():
z += 1
m = s
- c = GMPMod(z, self.n) ** int(q)
+ c = GMPMod(z, self.n, ensure=False) ** int(q)
t = self ** int(q)
r_exp = (q + 1) // 2
r = self ** int(r_exp)
@@ -639,17 +647,32 @@ if has_gmp:
while not (t ** (2 ** i)) == 1:
i += 1
two_exp = m - (i + 1)
- b = c ** int(GMPMod(2, self.n) ** two_exp)
- m = int(GMPMod(i, self.n))
+ b = c ** int(GMPMod(gmpy2.mpz(2), self.n, ensure=False) ** two_exp)
+ m = int(GMPMod(gmpy2.mpz(i), self.n, ensure=False))
c = b ** 2
t *= c
r *= b
return r
@_check
+ def __add__(self, other) -> "GMPMod":
+ return GMPMod((self.x + other.x) % self.n, self.n, ensure=False)
+
+ @_check
+ def __sub__(self, other) -> "GMPMod":
+ return GMPMod((self.x - other.x) % self.n, self.n, ensure=False)
+
+ def __neg__(self) -> "GMPMod":
+ return GMPMod(self.n - self.x, self.n, ensure=False)
+
+ @_check
+ def __mul__(self, other) -> "GMPMod":
+ return GMPMod((self.x * other.x) % self.n, self.n, ensure=False)
+
+ @_check
def __divmod__(self, divisor) -> Tuple["GMPMod", "GMPMod"]:
q, r = gmpy2.f_divmod(self.x, divisor.x)
- return GMPMod(q, self.n), GMPMod(r, self.n)
+ return GMPMod(q, self.n, ensure=False), GMPMod(r, self.n, ensure=False)
def __bytes__(self):
return int(self.x).to_bytes((self.n.bit_length() + 7) // 8, byteorder="big")
@@ -677,11 +700,11 @@ if has_gmp:
if type(n) not in (int, gmpy2.mpz):
raise TypeError
if n == 0:
- return GMPMod(1, self.n)
+ return GMPMod(gmpy2.mpz(1), self.n, ensure=False)
if n < 0:
return self.inverse() ** (-n)
if n == 1:
- return GMPMod(self.x, self.n)
- return GMPMod(gmpy2.powmod(self.x, gmpy2.mpz(n), self.n), self.n)
+ return GMPMod(self.x, self.n, ensure=False)
+ return GMPMod(gmpy2.powmod(self.x, gmpy2.mpz(n), self.n), self.n, ensure=False)
_mod_classes["gmp"] = GMPMod
diff --git a/test/ec/perf_formula.py b/test/ec/perf_formula.py
index b49daab..baa6347 100755
--- a/test/ec/perf_formula.py
+++ b/test/ec/perf_formula.py
@@ -31,7 +31,7 @@ def main(profiler, mod, operations, directory):
add = coords.formulas["add-2016-rcb"]
dbl = coords.formulas["dbl-2016-rcb"]
click.echo(
- f"Profiling {operations} {p256.curve.prime.bit_length()}-bit doubling formula executions..."
+ f"Profiling {operations} {p256.curve.prime.bit_length()}-bit doubling formula (dbl2016rcb) executions..."
)
one_point = p256.generator
with Profiler(
@@ -40,7 +40,7 @@ def main(profiler, mod, operations, directory):
for _ in range(operations):
one_point = dbl(p256.curve.prime, one_point, **p256.curve.parameters)[0]
click.echo(
- f"Profiling {operations} {p256.curve.prime.bit_length()}-bit addition formula executions..."
+ f"Profiling {operations} {p256.curve.prime.bit_length()}-bit addition formula (add2016rcb) executions..."
)
other_point = p256.generator
with Profiler(
@@ -54,7 +54,7 @@ def main(profiler, mod, operations, directory):
ecoords = ed25519.curve.coordinate_model
dblg = ecoords.formulas["mdbl-2008-hwcd"]
click.echo(
- f"Profiling {operations} {ed25519.curve.prime.bit_length()}-bit doubling formula executions (with assumption)..."
+ f"Profiling {operations} {ed25519.curve.prime.bit_length()}-bit doubling formula (mdbl2008hwcd) executions (with assumption)..."
)
eone_point = ed25519.generator
with Profiler(
diff --git a/test/ec/test_context.py b/test/ec/test_context.py
index 6691985..9cd74a3 100644
--- a/test/ec/test_context.py
+++ b/test/ec/test_context.py
@@ -3,12 +3,8 @@ from unittest import TestCase
from pyecsca.ec.context import (
local,
DefaultContext,
- NullContext,
- getcontext,
- setcontext,
- resetcontext,
Tree,
- PathContext,
+ PathContext
)
from pyecsca.ec.key_generation import KeyGeneration
from pyecsca.ec.params import get_params
@@ -68,18 +64,14 @@ class ContextTests(TestCase):
def test_null(self):
with local() as ctx:
self.mult.multiply(59)
- self.assertIsInstance(ctx, NullContext)
+ self.assertIs(ctx, None)
def test_default(self):
- token = setcontext(DefaultContext())
- self.addCleanup(resetcontext, token)
-
with local(DefaultContext()) as ctx:
result = self.mult.multiply(59)
self.assertEqual(len(ctx.actions), 1)
action = next(iter(ctx.actions.keys()))
self.assertIsInstance(action, ScalarMultiplicationAction)
- self.assertEqual(len(getcontext().actions), 0)
self.assertEqual(result, action.result)
def test_default_no_enter(self):
@@ -100,6 +92,5 @@ class ContextTests(TestCase):
self.mult.multiply(59)
str(default)
str(default.actions)
- with local(NullContext()) as null:
+ with local(None):
self.mult.multiply(59)
- str(null)
diff --git a/test/ec/test_mod.py b/test/ec/test_mod.py
index 7802e95..62022b0 100644
--- a/test/ec/test_mod.py
+++ b/test/ec/test_mod.py
@@ -176,6 +176,9 @@ class ModTests(TestCase):
"__hash__",
"__abstractmethods__",
"_abc_impl",
+ "__slots__",
+ "x",
+ "n"
):
continue
args = [5 for _ in range(meth.__code__.co_argcount - 1)]
diff --git a/test/utils.py b/test/utils.py
index d4893eb..e125813 100644
--- a/test/utils.py
+++ b/test/utils.py
@@ -62,7 +62,7 @@ class Profiler:
if self._state != "out":
raise ValueError
if self._prof_type == "py":
- print(self._prof.output_text(unicode=True, color=True))
+ print(self._prof.output_text(unicode=True, color=True, show_all=True))
else:
self._prof.print_stats("cumtime")