aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--pyecsca/ec/formula/__init__.py3
-rw-r--r--pyecsca/ec/formula/code.py71
-rw-r--r--pyecsca/ec/formula/graph.py67
-rw-r--r--test/ec/test_pickle.py13
4 files changed, 89 insertions, 65 deletions
diff --git a/pyecsca/ec/formula/__init__.py b/pyecsca/ec/formula/__init__.py
index b6efd8a..0e28f4f 100644
--- a/pyecsca/ec/formula/__init__.py
+++ b/pyecsca/ec/formula/__init__.py
@@ -1,4 +1,5 @@
-""""""
+"""Provides functionality for working with addition formulas."""
from .base import *
+from .code import *
from .efd import *
diff --git a/pyecsca/ec/formula/code.py b/pyecsca/ec/formula/code.py
new file mode 100644
index 0000000..9d36d94
--- /dev/null
+++ b/pyecsca/ec/formula/code.py
@@ -0,0 +1,71 @@
+"""Provides a concrete class of a formula that has a constructor."""
+
+from .base import (
+ Formula,
+ AdditionFormula,
+ DoublingFormula,
+ LadderFormula,
+ TriplingFormula,
+ NegationFormula,
+ ScalingFormula,
+ DifferentialAdditionFormula,
+)
+
+
+class CodeFormula(Formula):
+ def __init__(self, name, code, coordinate_model, parameters, assumptions):
+ self.name = name
+ self.coordinate_model = coordinate_model
+ self.meta = {}
+ self.parameters = parameters
+ self.assumptions = assumptions
+ self.code = code
+ self.unified = False
+
+ def __hash__(self):
+ return hash(
+ (
+ self.name,
+ self.coordinate_model,
+ tuple(self.code),
+ tuple(self.parameters),
+ tuple(self.assumptions),
+ )
+ )
+
+ def __eq__(self, other):
+ if not isinstance(other, CodeFormula):
+ return False
+ return (
+ self.name == other.name
+ and self.coordinate_model == other.coordinate_model
+ and self.code == other.code
+ )
+
+
+class CodeAdditionFormula(AdditionFormula, CodeFormula):
+ pass
+
+
+class CodeDoublingFormula(DoublingFormula, CodeFormula):
+ pass
+
+
+class CodeLadderFormula(LadderFormula, CodeFormula):
+ pass
+
+
+class CodeTriplingFormula(TriplingFormula, CodeFormula):
+ pass
+
+
+class CodeNegationFormula(NegationFormula, CodeFormula):
+ pass
+
+
+class CodeScalingFormula(ScalingFormula, CodeFormula):
+ pass
+
+
+class CodeDifferentialAdditionFormula(DifferentialAdditionFormula, CodeFormula):
+ pass
diff --git a/pyecsca/ec/formula/graph.py b/pyecsca/ec/formula/graph.py
index 8386973..3520b6a 100644
--- a/pyecsca/ec/formula/graph.py
+++ b/pyecsca/ec/formula/graph.py
@@ -1,13 +1,5 @@
-from . import (
- Formula,
- AdditionFormula,
- DoublingFormula,
- LadderFormula,
- TriplingFormula,
- NegationFormula,
- ScalingFormula,
- DifferentialAdditionFormula,
-)
+from .base import Formula
+from .code import CodeFormula
from ..op import CodeOp, OpType
import matplotlib.pyplot as plt
import networkx as nx
@@ -205,57 +197,6 @@ def formula_input_variables(formula: Formula) -> List[str]:
)
-class CodeFormula(Formula):
- def __init__(self, name, code, coordinate_model, parameters, assumptions):
- self.name = name
- self.coordinate_model = coordinate_model
- self.meta = {}
- self.parameters = parameters
- self.assumptions = assumptions
- self.code = code
- self.unified = False
-
- def __hash__(self):
- return hash((self.name, self.coordinate_model, tuple(self.code), tuple(self.parameters), tuple(self.assumptions)))
-
- def __eq__(self, other):
- if not isinstance(other, CodeFormula):
- return False
- return (
- self.name == other.name
- and self.coordinate_model == other.coordinate_model
- and self.code == other.code
- )
-
-
-class CodeAdditionFormula(AdditionFormula, CodeFormula):
- pass
-
-
-class CodeDoublingFormula(DoublingFormula, CodeFormula):
- pass
-
-
-class CodeLadderFormula(LadderFormula, CodeFormula):
- pass
-
-
-class CodeTriplingFormula(TriplingFormula, CodeFormula):
- pass
-
-
-class CodeNegationFormula(NegationFormula, CodeFormula):
- pass
-
-
-class CodeScalingFormula(ScalingFormula, CodeFormula):
- pass
-
-
-class CodeDifferentialAdditionFormula(DifferentialAdditionFormula, CodeFormula):
- pass
-
-
class FormulaGraph:
coordinate_model: Any
shortname: str
@@ -323,9 +264,7 @@ class FormulaGraph:
assumptions = [deepcopy(assumption) for assumption in self.assumptions]
for klass in CodeFormula.__subclasses__():
if klass.shortname == self.shortname:
- return klass(
- name, code, self.coordinate_model, parameters, assumptions
- )
+ return klass(name, code, self.coordinate_model, parameters, assumptions)
raise ValueError(f"Bad formula type: {self.shortname}")
def networkx_graph(self) -> nx.DiGraph:
diff --git a/test/ec/test_pickle.py b/test/ec/test_pickle.py
index 5d448b5..36f014d 100644
--- a/test/ec/test_pickle.py
+++ b/test/ec/test_pickle.py
@@ -55,6 +55,19 @@ def test_formula():
assert formulas == back
+def formula_target(formula):
+ return hasattr(formula, "coordinate_model")
+
+
+def test_formula_loads(ctx):
+ sw = ShortWeierstrassModel()
+ coords = sw.coordinates["projective"]
+ formula = coords.formulas["add-2007-bl"]
+ with ctx.Pool(processes=1) as pool:
+ res = pool.apply(formula_target, args=(formula,))
+ assert res
+
+
def test_code_formula():
sw = ShortWeierstrassModel()
coords = sw.coordinates["projective"]