diff options
Diffstat (limited to 'pyecsca/ec/context.py')
| -rw-r--r-- | pyecsca/ec/context.py | 173 |
1 files changed, 45 insertions, 128 deletions
diff --git a/pyecsca/ec/context.py b/pyecsca/ec/context.py index 8f797a8..d6f56af 100644 --- a/pyecsca/ec/context.py +++ b/pyecsca/ec/context.py @@ -1,85 +1,45 @@ -import ast from abc import ABCMeta, abstractmethod +from collections import OrderedDict from contextvars import ContextVar, Token from copy import deepcopy -from typing import List, Tuple, Optional, Union, MutableMapping, Any, ContextManager +from typing import List, Optional, ContextManager, Any from public import public -from .formula import Formula -from .mod import Mod -from .op import CodeOp, OpType -from .point import Point - - -@public -class OpResult(object): - """A result of an operation.""" - parents: Tuple - op: OpType - name: str - value: Mod - - def __init__(self, name: str, value: Mod, op: OpType, *parents: Any): - self.parents = tuple(parents) - self.name = name - self.value = value - self.op = op - - def __str__(self): - return self.name - - def __repr__(self): - char = self.op.op_str - parents = char.join(str(parent) for parent in self.parents) - return f"{self.name} = {parents}" - @public class Action(object): - """An execution of some operations with inputs and outputs.""" - inputs: MutableMapping[str, Mod] - input_points: List[Point] - intermediates: MutableMapping[str, OpResult] - outputs: MutableMapping[str, OpResult] - output_points: List[Point] + """An Action.""" + inside: bool - def __init__(self, *points: Point, **inputs: Mod): - self.inputs = inputs - self.intermediates = {} - self.outputs = {} - self.input_points = list(points) - self.output_points = [] + def __init__(self): + self.inside = False - def add_operation(self, op: CodeOp, value: Mod): - parents: List[Union[Mod, OpResult]] = [] - for parent in {*op.variables, *op.parameters}: - if parent in self.intermediates: - parents.append(self.intermediates[parent]) - elif parent in self.inputs: - parents.append(self.inputs[parent]) - self.intermediates[op.result] = OpResult(op.result, value, op.operator, *parents) + def __enter__(self): + getcontext().enter_action(self) + self.inside = True + return self - def add_result(self, point: Point, **outputs: Mod): - for k in outputs: - self.outputs[k] = self.intermediates[k] - self.output_points.append(point) + def __exit__(self, exc_type, exc_val, exc_tb): + getcontext().exit_action(self) + self.inside = False - def __repr__(self): - return f"{self.__class__.__name__}({self.input_points}) = {self.output_points}" -@public -class FormulaAction(Action): - """An execution of a formula, on some input points and parameters, with some outputs.""" - formula: Formula - def __init__(self, formula: Formula, *points: Point, **inputs: Mod): - super().__init__(*points, **inputs) - self.formula = formula - def __repr__(self): - return f"{self.__class__.__name__}({self.formula}, {self.input_points}) = {self.output_points}" +class Tree(OrderedDict): + + def walk_get(self, path: List) -> Any: + if len(path) == 0: + return self + value = self[path[0]] + if isinstance(value, Tree): + return value.walk_get(path[1:]) + elif len(path) == 1: + return value + else: + raise ValueError @public @@ -87,92 +47,49 @@ class Context(object): __metaclass__ = ABCMeta @abstractmethod - def _log_formula(self, formula: Formula, *points: Point, **inputs: Mod): - ... - - @abstractmethod - def _log_operation(self, op: CodeOp, value: Mod): + def enter_action(self, action: Action): ... @abstractmethod - def _log_result(self, point: Point, **outputs: Mod): + def exit_action(self, action: Action): ... - def _execute(self, formula: Formula, *points: Point, **params: Mod) -> Tuple[Point, ...]: - if len(points) != formula.num_inputs: - raise ValueError(f"Wrong number of inputs for {formula}.") - coords = {} - for i, point in enumerate(points): - if point.coordinate_model != formula.coordinate_model: - raise ValueError(f"Wrong coordinate model of point {point}.") - for coord, value in point.coords.items(): - coords[coord + str(i + 1)] = value - locals = {**coords, **params} - self._log_formula(formula, *points, **locals) - for op in formula.code: - op_result = op(**locals) - self._log_operation(op, op_result) - locals[op.result] = op_result - result = [] - for i in range(formula.num_outputs): - ind = str(i + formula.output_index) - resulting = {} - full_resulting = {} - for variable in formula.coordinate_model.variables: - full_variable = variable + ind - resulting[variable] = locals[full_variable] - full_resulting[full_variable] = locals[full_variable] - point = Point(formula.coordinate_model, **resulting) - - self._log_result(point, **full_resulting) - result.append(point) - return tuple(result) - - def execute(self, formula: Formula, *points: Point, **params: Mod) -> Tuple[Point, ...]: - """ - Execute a formula. - - :param formula: Formula to execute. - :param points: Points to pass into the formula. - :param params: Parameters of the curve. - :return: The resulting point(s). - """ - return self._execute(formula, *points, **params) - def __str__(self): return self.__class__.__name__ @public class NullContext(Context): - """A context that does not trace any operations.""" - - def _log_formula(self, formula: Formula, *points: Point, **inputs: Mod): - pass + """A context that does not trace any actions.""" - def _log_operation(self, op: CodeOp, value: Mod): + def enter_action(self, action: Action): pass - def _log_result(self, point: Point, **outputs: Mod): + def exit_action(self, action: Action): pass @public class DefaultContext(Context): - """A context that traces executions of formulas.""" - actions: List[FormulaAction] + """A context that traces executions of actions.""" + actions: Tree + current: List[Action] - def __init__(self): - self.actions = [] + def enter_action(self, action: Action): + self.actions.walk_get(self.current)[action] = Tree() + self.current.append(action) - def _log_formula(self, formula: Formula, *points: Point, **inputs: Mod): - self.actions.append(FormulaAction(formula, *points, **inputs)) + def exit_action(self, action: Action): + if self.current[-1] != action: + raise ValueError + self.current.pop() - def _log_operation(self, op: CodeOp, value: Mod): - self.actions[-1].add_operation(op, value) + def __init__(self): + self.actions = Tree() + self.current = [] - def _log_result(self, point: Point, **outputs: Mod): - self.actions[-1].add_result(point, **outputs) + def __repr__(self): + return f"{self.__class__.__name__}({self.actions}, current={self.current})" _actual_context: ContextVar[Context] = ContextVar("operational_context", default=NullContext()) |
