aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorJ08nY2023-02-12 19:05:28 +0100
committerJ08nY2023-02-12 19:05:28 +0100
commit7421fce192b581d732eabf4b2948bd8546b4afea (patch)
tree486ccd2159e049a96607cef832841f463ef6e56c
parent7c3b286a78a294a0eb6cd052682b9ee92420fa4d (diff)
downloadpyecsca-7421fce192b581d732eabf4b2948bd8546b4afea.tar.gz
pyecsca-7421fce192b581d732eabf4b2948bd8546b4afea.tar.zst
pyecsca-7421fce192b581d732eabf4b2948bd8546b4afea.zip
Get rid of getcontext/setcontext functions.
-rw-r--r--pyecsca/ec/context.py72
-rw-r--r--pyecsca/ec/formula.py13
-rw-r--r--test/ec/test_context.py15
-rw-r--r--test/sca/test_rpa.py1
-rw-r--r--test/utils.py2
5 files changed, 28 insertions, 75 deletions
diff --git a/pyecsca/ec/context.py b/pyecsca/ec/context.py
index 691ae71..49c886d 100644
--- a/pyecsca/ec/context.py
+++ b/pyecsca/ec/context.py
@@ -14,7 +14,6 @@ A :py:class:`NullContext` does not trace any actions and is the default context.
"""
from abc import abstractmethod, ABC
from collections import OrderedDict
-from contextvars import ContextVar, Token
from copy import deepcopy
from typing import List, Optional, ContextManager, Any, Tuple, Sequence
@@ -31,12 +30,14 @@ class Action:
self.inside = False
def __enter__(self):
- getcontext().enter_action(self)
+ if current is not None:
+ current.enter_action(self)
self.inside = True
return self
def __exit__(self, exc_type, exc_val, exc_tb):
- getcontext().exit_action(self)
+ if current is not None:
+ current.exit_action(self)
self.inside = False
@@ -64,10 +65,10 @@ class ResultAction(Action):
def __exit__(self, exc_type, exc_val, exc_tb):
if (
- not self._has_result
- and exc_type is None
- and exc_val is None
- and exc_tb is None
+ not self._has_result
+ and exc_type is None
+ and exc_val is None
+ and exc_tb is None
):
raise RuntimeError("Result unset on action exit")
super().__exit__(exc_type, exc_val, exc_tb)
@@ -166,17 +167,6 @@ class Context(ABC):
@public
-class NullContext(Context):
- """Context that does not trace any actions."""
-
- def enter_action(self, action: Action) -> None:
- pass # Nothing to enter as no action is traced.
-
- def exit_action(self, action: Action) -> None:
- pass # Nothing to exit as no action is traced.
-
-
-@public
class DefaultContext(Context):
"""Context that traces executions of actions in a tree."""
@@ -240,50 +230,22 @@ class PathContext(Context):
)
-_actual_context: Context = NullContext()
+current: Optional[Context] = None
class _ContextManager:
def __init__(self, new_context):
self.new_context = deepcopy(new_context)
- def __enter__(self) -> Context:
- self.token = setcontext(self.new_context)
- return self.new_context
+ def __enter__(self) -> Optional[Context]:
+ global current
+ self.old_context = current
+ current = self.new_context
+ return current
def __exit__(self, t, v, tb):
- resetcontext(self.token)
-
-
-@public
-def getcontext() -> Context:
- """Get the current thread/task context."""
- return _actual_context
-
-
-@public
-def setcontext(ctx: Context) -> Token:
- """
- Set the current thread/task context.
-
- :param ctx: A context to set.
- :return: A token to restore previous context.
- """
- global _actual_context
- old = _actual_context
- _actual_context = ctx
- return old
-
-
-@public
-def resetcontext(token: Token):
- """
- Reset the context to a previous value.
-
- :param token: A token to restore.
- """
- global _actual_context
- _actual_context = token
+ global current
+ current = self.old_context
@public
@@ -295,5 +257,5 @@ def local(ctx: Optional[Context] = None) -> ContextManager:
:return: A context manager.
"""
if ctx is None:
- ctx = getcontext()
+ ctx = current
return _ContextManager(ctx)
diff --git a/pyecsca/ec/formula.py b/pyecsca/ec/formula.py
index 994709c..7acdd5a 100644
--- a/pyecsca/ec/formula.py
+++ b/pyecsca/ec/formula.py
@@ -9,7 +9,8 @@ from pkg_resources import resource_stream
from public import public
from sympy import sympify, FF, symbols, Poly, Rational
-from .context import ResultAction, getcontext, NullContext
+from .context import ResultAction
+from . import context
from .error import UnsatisfiedAssumptionError, raise_unsatisified_assumption
from .mod import Mod, SymbolicMod
from .op import CodeOp, OpType
@@ -67,8 +68,6 @@ class FormulaAction(ResultAction):
self.output_points = []
def add_operation(self, op: CodeOp, value: Mod):
- if isinstance(getcontext(), NullContext):
- return
parents: List[Union[Mod, OpResult]] = []
for parent in {*op.variables, *op.parameters}:
if parent in self.intermediates:
@@ -79,8 +78,6 @@ class FormulaAction(ResultAction):
li.append(OpResult(op.result, value, op.operator, *parents))
def add_result(self, point: Any, **outputs: Mod):
- if isinstance(getcontext(), NullContext):
- return
for k in outputs:
self.outputs[k] = self.intermediates[k][-1]
self.output_points.append(point)
@@ -234,7 +231,8 @@ class Formula(ABC):
)
if not isinstance(op_result, Mod):
op_result = Mod(op_result, field)
- action.add_operation(op, op_result)
+ if context.current is not None:
+ action.add_operation(op, op_result)
params[op.result] = op_result
result = []
# Go over the outputs and construct the resulting points.
@@ -248,7 +246,8 @@ class Formula(ABC):
full_resulting[full_variable] = params[full_variable]
point = Point(self.coordinate_model, **resulting)
- action.add_result(point, **full_resulting)
+ if context.current is not None:
+ action.add_result(point, **full_resulting)
result.append(point)
return action.exit(tuple(result))
diff --git a/test/ec/test_context.py b/test/ec/test_context.py
index 6691985..9cd74a3 100644
--- a/test/ec/test_context.py
+++ b/test/ec/test_context.py
@@ -3,12 +3,8 @@ from unittest import TestCase
from pyecsca.ec.context import (
local,
DefaultContext,
- NullContext,
- getcontext,
- setcontext,
- resetcontext,
Tree,
- PathContext,
+ PathContext
)
from pyecsca.ec.key_generation import KeyGeneration
from pyecsca.ec.params import get_params
@@ -68,18 +64,14 @@ class ContextTests(TestCase):
def test_null(self):
with local() as ctx:
self.mult.multiply(59)
- self.assertIsInstance(ctx, NullContext)
+ self.assertIs(ctx, None)
def test_default(self):
- token = setcontext(DefaultContext())
- self.addCleanup(resetcontext, token)
-
with local(DefaultContext()) as ctx:
result = self.mult.multiply(59)
self.assertEqual(len(ctx.actions), 1)
action = next(iter(ctx.actions.keys()))
self.assertIsInstance(action, ScalarMultiplicationAction)
- self.assertEqual(len(getcontext().actions), 0)
self.assertEqual(result, action.result)
def test_default_no_enter(self):
@@ -100,6 +92,5 @@ class ContextTests(TestCase):
self.mult.multiply(59)
str(default)
str(default.actions)
- with local(NullContext()) as null:
+ with local(None):
self.mult.multiply(59)
- str(null)
diff --git a/test/sca/test_rpa.py b/test/sca/test_rpa.py
index 995099d..325d796 100644
--- a/test/sca/test_rpa.py
+++ b/test/sca/test_rpa.py
@@ -3,6 +3,7 @@ from unittest import TestCase
from parameterized import parameterized
from pyecsca.ec.context import local
+from pyecsca.ec import context
from pyecsca.ec.mult import (
LTRMultiplier,
BinaryNAFMultiplier,
diff --git a/test/utils.py b/test/utils.py
index d4893eb..e125813 100644
--- a/test/utils.py
+++ b/test/utils.py
@@ -62,7 +62,7 @@ class Profiler:
if self._state != "out":
raise ValueError
if self._prof_type == "py":
- print(self._prof.output_text(unicode=True, color=True))
+ print(self._prof.output_text(unicode=True, color=True, show_all=True))
else:
self._prof.print_stats("cumtime")