aboutsummaryrefslogtreecommitdiffhomepage
path: root/pyecsca/ec
diff options
context:
space:
mode:
Diffstat (limited to 'pyecsca/ec')
-rw-r--r--pyecsca/ec/coordinates.py8
-rw-r--r--pyecsca/ec/formula/__init__.py4
-rw-r--r--pyecsca/ec/formula/base.py (renamed from pyecsca/ec/formula.py)146
-rw-r--r--pyecsca/ec/formula/efd.py142
-rw-r--r--pyecsca/ec/formula/expand.py51
-rw-r--r--pyecsca/ec/formula/fliparoo.py378
-rw-r--r--pyecsca/ec/formula/graph.py436
-rw-r--r--pyecsca/ec/formula/metrics.py100
-rw-r--r--pyecsca/ec/formula/partitions.py378
-rw-r--r--pyecsca/ec/formula/switch_sign.py131
-rw-r--r--pyecsca/ec/mult/window.py170
-rw-r--r--pyecsca/ec/point.py2
-rw-r--r--pyecsca/ec/scalar.py63
13 files changed, 1846 insertions, 163 deletions
diff --git a/pyecsca/ec/coordinates.py b/pyecsca/ec/coordinates.py
index 49452c7..8692809 100644
--- a/pyecsca/ec/coordinates.py
+++ b/pyecsca/ec/coordinates.py
@@ -7,8 +7,8 @@ from typing import List, Any, MutableMapping
from public import public
-from .formula import (
- Formula,
+from .formula import Formula
+from .formula.efd import (
EFDFormula,
AdditionEFDFormula,
DoublingEFDFormula,
@@ -55,7 +55,9 @@ class CoordinateModel:
return f"{self.curve_model.shortname}/{self.name}"
def __repr__(self):
- return f"{self.__class__.__name__}(\"{self.name}\", curve_model={self.curve_model})"
+ return (
+ f'{self.__class__.__name__}("{self.name}", curve_model={self.curve_model})'
+ )
def __getstate__(self):
state = self.__dict__.copy()
diff --git a/pyecsca/ec/formula/__init__.py b/pyecsca/ec/formula/__init__.py
new file mode 100644
index 0000000..b6efd8a
--- /dev/null
+++ b/pyecsca/ec/formula/__init__.py
@@ -0,0 +1,4 @@
+""""""
+
+from .base import *
+from .efd import *
diff --git a/pyecsca/ec/formula.py b/pyecsca/ec/formula/base.py
index fff3210..733a82d 100644
--- a/pyecsca/ec/formula.py
+++ b/pyecsca/ec/formula/base.py
@@ -1,24 +1,22 @@
-"""Provides an abstract base class of a formula along with concrete instantiations."""
+"""Provides an abstract base class of a formula."""
from abc import ABC, abstractmethod
-from ast import parse, Expression
+from ast import Expression
from functools import cached_property
from astunparse import unparse
-from itertools import product
from typing import List, Set, Any, ClassVar, MutableMapping, Tuple, Union, Dict
-from importlib_resources.abc import Traversable
from public import public
from sympy import FF, symbols, Poly, Rational, simplify
-from ..misc.cache import sympify
-from .context import ResultAction
-from . import context
-from .error import UnsatisfiedAssumptionError, raise_unsatisified_assumption
-from .mod import Mod, SymbolicMod
-from .op import CodeOp, OpType
-from ..misc.cfg import getconfig
-from ..misc.utils import peval
+from ..context import ResultAction
+from .. import context
+from ..error import UnsatisfiedAssumptionError, raise_unsatisified_assumption
+from ..mod import Mod, SymbolicMod
+from ..op import CodeOp, OpType
+from ...misc.cfg import getconfig
+from ...misc.utils import peval
+from ...misc.cache import sympify
@public
@@ -235,7 +233,7 @@ class Formula(ABC):
:param params: Parameters of the curve.
:return: The resulting point(s).
"""
- from .point import Point
+ from ..point import Point
self.__validate_params(field, params)
self.__validate_points(field, points, params)
@@ -351,93 +349,6 @@ class Formula(ABC):
)
-class EFDFormula(Formula):
- """Formula from the [EFD]_."""
-
- def __init__(
- self,
- meta_path: Traversable,
- op3_path: Traversable,
- name: str,
- coordinate_model: Any,
- ):
- self.name = name
- self.coordinate_model = coordinate_model
- self.meta = {}
- self.parameters = []
- self.assumptions = []
- self.code = []
- self.unified = False
- self.__read_meta_file(meta_path)
- self.__read_op3_file(op3_path)
-
- def __read_meta_file(self, path: Traversable):
- with path.open("rb") as f:
- line = f.readline().decode("ascii").rstrip()
- while line:
- if line.startswith("source"):
- self.meta["source"] = line[7:]
- elif line.startswith("parameter"):
- self.parameters.append(line[10:])
- elif line.startswith("assume"):
- self.assumptions.append(
- parse(
- line[7:].replace("=", "==").replace("^", "**"), mode="eval"
- )
- )
- elif line.startswith("unified"):
- self.unified = True
- line = f.readline().decode("ascii").rstrip()
-
- def __read_op3_file(self, path: Traversable):
- with path.open("rb") as f:
- for line in f.readlines():
- code_module = parse(
- line.decode("ascii").replace("^", "**"), str(path), mode="exec"
- )
- self.code.append(CodeOp(code_module))
-
- def __str__(self):
- return f"{self.coordinate_model!s}/{self.name}"
-
- @cached_property
- def input_index(self):
- return 1
-
- @cached_property
- def output_index(self):
- return max(self.num_inputs + 1, 3)
-
- @cached_property
- def inputs(self):
- return {
- var + str(i)
- for var, i in product(
- self.coordinate_model.variables, range(1, 1 + self.num_inputs)
- )
- }
-
- @cached_property
- def outputs(self):
- return {
- var + str(i)
- for var, i in product(
- self.coordinate_model.variables,
- range(self.output_index, self.output_index + self.num_outputs),
- )
- }
-
- def __eq__(self, other):
- if not isinstance(other, EFDFormula):
- return False
- return (
- self.name == other.name and self.coordinate_model == other.coordinate_model
- )
-
- def __hash__(self):
- return hash((self.coordinate_model, self.name))
-
-
@public
class AdditionFormula(Formula, ABC):
"""Formula that adds two points."""
@@ -448,11 +359,6 @@ class AdditionFormula(Formula, ABC):
@public
-class AdditionEFDFormula(AdditionFormula, EFDFormula):
- pass
-
-
-@public
class DoublingFormula(Formula, ABC):
"""Formula that doubles a point."""
@@ -462,11 +368,6 @@ class DoublingFormula(Formula, ABC):
@public
-class DoublingEFDFormula(DoublingFormula, EFDFormula):
- pass
-
-
-@public
class TriplingFormula(Formula, ABC):
"""Formula that triples a point."""
@@ -476,11 +377,6 @@ class TriplingFormula(Formula, ABC):
@public
-class TriplingEFDFormula(TriplingFormula, EFDFormula):
- pass
-
-
-@public
class NegationFormula(Formula, ABC):
"""Formula that negates a point."""
@@ -490,11 +386,6 @@ class NegationFormula(Formula, ABC):
@public
-class NegationEFDFormula(NegationFormula, EFDFormula):
- pass
-
-
-@public
class ScalingFormula(Formula, ABC):
"""Formula that somehow scales the point (to a given representative of a projective class)."""
@@ -504,11 +395,6 @@ class ScalingFormula(Formula, ABC):
@public
-class ScalingEFDFormula(ScalingFormula, EFDFormula):
- pass
-
-
-@public
class DifferentialAdditionFormula(Formula, ABC):
"""
Differential addition formula that adds two points with a known difference.
@@ -522,11 +408,6 @@ class DifferentialAdditionFormula(Formula, ABC):
@public
-class DifferentialAdditionEFDFormula(DifferentialAdditionFormula, EFDFormula):
- pass
-
-
-@public
class LadderFormula(Formula, ABC):
"""
Ladder formula for simultaneous addition of two points and doubling of the one of them, with a known difference.
@@ -539,8 +420,3 @@ class LadderFormula(Formula, ABC):
shortname = "ladd"
num_inputs = 3
num_outputs = 2
-
-
-@public
-class LadderEFDFormula(LadderFormula, EFDFormula):
- pass
diff --git a/pyecsca/ec/formula/efd.py b/pyecsca/ec/formula/efd.py
new file mode 100644
index 0000000..0156fb3
--- /dev/null
+++ b/pyecsca/ec/formula/efd.py
@@ -0,0 +1,142 @@
+""""""
+from functools import cached_property
+from itertools import product
+
+from public import public
+
+from importlib_resources.abc import Traversable
+from typing import Any
+from .base import (
+ Formula,
+ CodeOp,
+ AdditionFormula,
+ DoublingFormula,
+ TriplingFormula,
+ NegationFormula,
+ ScalingFormula,
+ DifferentialAdditionFormula,
+ LadderFormula,
+)
+from ast import parse
+
+
+class EFDFormula(Formula):
+ """Formula from the [EFD]_."""
+
+ def __init__(
+ self,
+ meta_path: Traversable,
+ op3_path: Traversable,
+ name: str,
+ coordinate_model: Any,
+ ):
+ self.name = name
+ self.coordinate_model = coordinate_model
+ self.meta = {}
+ self.parameters = []
+ self.assumptions = []
+ self.code = []
+ self.unified = False
+ self.__read_meta_file(meta_path)
+ self.__read_op3_file(op3_path)
+
+ def __read_meta_file(self, path: Traversable):
+ with path.open("rb") as f:
+ line = f.readline().decode("ascii").rstrip()
+ while line:
+ if line.startswith("source"):
+ self.meta["source"] = line[7:]
+ elif line.startswith("parameter"):
+ self.parameters.append(line[10:])
+ elif line.startswith("assume"):
+ self.assumptions.append(
+ parse(
+ line[7:].replace("=", "==").replace("^", "**"), mode="eval"
+ )
+ )
+ elif line.startswith("unified"):
+ self.unified = True
+ line = f.readline().decode("ascii").rstrip()
+
+ def __read_op3_file(self, path: Traversable):
+ with path.open("rb") as f:
+ for line in f.readlines():
+ code_module = parse(
+ line.decode("ascii").replace("^", "**"), str(path), mode="exec"
+ )
+ self.code.append(CodeOp(code_module))
+
+ def __str__(self):
+ return f"{self.coordinate_model!s}/{self.name}"
+
+ @cached_property
+ def input_index(self):
+ return 1
+
+ @cached_property
+ def output_index(self):
+ return max(self.num_inputs + 1, 3)
+
+ @cached_property
+ def inputs(self):
+ return {
+ var + str(i)
+ for var, i in product(
+ self.coordinate_model.variables, range(1, 1 + self.num_inputs)
+ )
+ }
+
+ @cached_property
+ def outputs(self):
+ return {
+ var + str(i)
+ for var, i in product(
+ self.coordinate_model.variables,
+ range(self.output_index, self.output_index + self.num_outputs),
+ )
+ }
+
+ def __eq__(self, other):
+ if not isinstance(other, EFDFormula):
+ return False
+ return (
+ self.name == other.name and self.coordinate_model == other.coordinate_model
+ )
+
+ def __hash__(self):
+ return hash((self.coordinate_model, self.name))
+
+
+@public
+class AdditionEFDFormula(AdditionFormula, EFDFormula):
+ pass
+
+
+@public
+class DoublingEFDFormula(DoublingFormula, EFDFormula):
+ pass
+
+
+@public
+class TriplingEFDFormula(TriplingFormula, EFDFormula):
+ pass
+
+
+@public
+class NegationEFDFormula(NegationFormula, EFDFormula):
+ pass
+
+
+@public
+class ScalingEFDFormula(ScalingFormula, EFDFormula):
+ pass
+
+
+@public
+class DifferentialAdditionEFDFormula(DifferentialAdditionFormula, EFDFormula):
+ pass
+
+
+@public
+class LadderEFDFormula(LadderFormula, EFDFormula):
+ pass
diff --git a/pyecsca/ec/formula/expand.py b/pyecsca/ec/formula/expand.py
new file mode 100644
index 0000000..3265131
--- /dev/null
+++ b/pyecsca/ec/formula/expand.py
@@ -0,0 +1,51 @@
+from typing import List, Callable, Any
+from public import public
+
+from . import Formula
+from .efd import EFDFormula
+from .fliparoo import recursive_fliparoo
+from .graph import ModifiedEFDFormula
+from .metrics import ivs_norm
+from .partitions import reduce_all_adds, expand_all_muls, expand_all_nopower2_muls
+from .switch_sign import generate_switched_formulas
+
+
+def reduce_with_similarity(formulas: List[EFDFormula], norm):
+ efd = list(filter(lambda x: not isinstance(x, ModifiedEFDFormula), formulas))
+ reduced_efd = efd
+ similarities = list(map(norm, efd))
+ for formula in formulas:
+ n = norm(formula)
+ if n in similarities:
+ continue
+ similarities.append(n)
+ reduced_efd.append(formula)
+ return reduced_efd
+
+
+@public
+def expand_formula_list(
+ formulas: List[EFDFormula], norm: Callable[[Formula], Any] = ivs_norm
+) -> List[EFDFormula]:
+ extended = reduce_with_similarity(formulas, norm)
+
+ fliparood: List[EFDFormula] = sum(list(map(recursive_fliparoo, extended)), [])
+ extended.extend(fliparood)
+ extended = reduce_with_similarity(extended, norm)
+
+ switch_signs: List[EFDFormula] = sum(
+ [list(generate_switched_formulas(f)) for f in extended], []
+ )
+ extended.extend(switch_signs)
+ extended = reduce_with_similarity(extended, norm)
+
+ extended.extend(list(map(reduce_all_adds, extended)))
+ extended = reduce_with_similarity(extended, norm)
+
+ extended.extend(list(map(expand_all_muls, extended)))
+ extended = reduce_with_similarity(extended, norm)
+
+ extended.extend(list(map(expand_all_nopower2_muls, extended)))
+ extended = reduce_with_similarity(extended, norm)
+
+ return extended
diff --git a/pyecsca/ec/formula/fliparoo.py b/pyecsca/ec/formula/fliparoo.py
new file mode 100644
index 0000000..c8d77ac
--- /dev/null
+++ b/pyecsca/ec/formula/fliparoo.py
@@ -0,0 +1,378 @@
+from typing import Iterator, List, Tuple, Type, Optional
+from ..op import OpType
+from .graph import EFDFormulaGraph, Node, CodeOpNode, CodeOp, parse
+from .efd import EFDFormula
+from random import randint
+
+
+class Fliparoo:
+ """
+ Fliparoo is a chain of nodes N1->N2->...->Nk in EFDFormulaGraph for k>=2 such that:
+ - All Ni are * or All Ni are +/-
+ - For every Ni, except for Nk, the only outgoing node is Ni+1
+ - Neither of N1,...,Nk-1 is an output node
+ """
+
+ nodes: List[CodeOpNode]
+ graph: EFDFormulaGraph
+ operator: Optional[OpType]
+
+ def __init__(self, chain: List[CodeOpNode], graph: EFDFormulaGraph):
+ self.verify_chain(chain)
+ self.nodes = chain
+ self.graph = graph
+ self.operator = None
+
+ def verify_chain(self, nodes: List[CodeOpNode]):
+ for i, node in enumerate(nodes[:-1]):
+ if node.outgoing_nodes != [nodes[i + 1]]:
+ raise BadFliparoo
+ if node.output_node:
+ raise BadFliparoo
+
+ @property
+ def first(self):
+ return self.nodes[0]
+
+ @property
+ def last(self):
+ return self.nodes[-1]
+
+ def __len__(self):
+ return len(self.nodes)
+
+ def __repr__(self):
+ return "->".join(map(lambda x: x.__repr__(), self.nodes))
+
+ def previous(self, node: CodeOpNode) -> Optional[CodeOpNode]:
+ i = self.nodes.index(node)
+ if i == 0:
+ return None
+ return self.nodes[i - 1]
+
+ def __getitem__(self, i: int):
+ return self.nodes[i]
+
+ def __iter__(self):
+ return iter(self.nodes)
+
+ def __eq__(self, other):
+ return self.graph == other.graph and self.nodes == other.nodes
+
+ def deepcopy(self):
+ ngraph = self.graph.deepcopy()
+ nchain = [mirror_node(node, self.graph, ngraph) for node in self.nodes]
+ return self.__class__(nchain, ngraph)
+
+ def input_nodes(self) -> List[Node]:
+ input_nodes: List[Node] = []
+ for node in self:
+ input_nodes.extend(
+ filter(lambda x: x not in self.nodes, node.incoming_nodes)
+ )
+ return input_nodes
+
+ def slice(self, i: int, j: int):
+ return self.__class__(self.nodes[i:j], self.graph)
+
+
+class MulFliparoo(Fliparoo):
+ def __init__(self, chain: List[CodeOpNode], graph: EFDFormulaGraph):
+ super().__init__(chain, graph)
+ operations = set(node.op.operator for node in self.nodes)
+ if len(operations) != 1 or list(operations)[0] != OpType.Mult:
+ raise BadFliparoo
+ self.operator = OpType.Mult
+
+
+class AddSubFliparoo(Fliparoo):
+ def __init__(self, chain: List[CodeOpNode], graph: EFDFormulaGraph):
+ super().__init__(chain, graph)
+ operations = set(node.op.operator for node in self.nodes)
+ if not operations.issubset([OpType.Add, OpType.Sub]):
+ raise BadFliparoo
+
+
+class AddFliparoo(Fliparoo):
+ def __init__(self, chain: List[CodeOpNode], graph: EFDFormulaGraph):
+ super().__init__(chain, graph)
+ operations = set(node.op.operator for node in self.nodes)
+ if len(operations) != 1 or list(operations)[0] != OpType.Add:
+ raise BadFliparoo
+ self.operator = OpType.Add
+
+
+class BadFliparoo(Exception):
+ pass
+
+
+def find_fliparoos(
+ graph: EFDFormulaGraph, fliparoo_type: Optional[Type[Fliparoo]] = None
+) -> List[Fliparoo]:
+ """Finds a list of Fliparoos in a graph"""
+ fliparoos: List[Fliparoo] = []
+ for ilong_path in graph.find_all_paths():
+ long_path = ilong_path[1:] # get rid of the input variables
+ fliparoo = largest_fliparoo(long_path, graph, fliparoo_type) # type: ignore
+ if fliparoo and fliparoo not in fliparoos:
+ fliparoos.append(fliparoo)
+
+ # remove duplicities and fliparoos in inclusion
+ fliparoos = sorted(fliparoos, key=len, reverse=True)
+ longest_fliparoos: List[Fliparoo] = []
+ for fliparoo in fliparoos:
+ if not is_subfliparoo(fliparoo, longest_fliparoos):
+ longest_fliparoos.append(fliparoo)
+ return longest_fliparoos
+
+
+def is_subfliparoo(fliparoo: Fliparoo, longest_fliparoos: List[Fliparoo]) -> bool:
+ """Returns true if fliparoo is a part of any fliparoo in longest_fliparoos"""
+ for other_fliparoo in longest_fliparoos:
+ l1, l2 = len(fliparoo), len(other_fliparoo)
+ for i in range(l2 - l1):
+ if other_fliparoo.slice(i, i + l1) == fliparoo:
+ return True
+ return False
+
+
+def largest_fliparoo(
+ chain: List[CodeOpNode],
+ graph: EFDFormulaGraph,
+ fliparoo_type: Optional[Type[Fliparoo]] = None,
+) -> Optional[Fliparoo]:
+ """Finds the largest fliparoo in a list of Nodes"""
+ for i in range(len(chain) - 1):
+ for j in range(len(chain) - 1, i, -1):
+ subchain = chain[i : j + 1]
+ if fliparoo_type:
+ try:
+ fliparoo_type(subchain, graph)
+ except BadFliparoo:
+ continue
+ try:
+ return MulFliparoo(subchain, graph)
+ except BadFliparoo:
+ pass
+ try:
+ return AddSubFliparoo(subchain, graph)
+ except BadFliparoo:
+ pass
+ return None
+
+
+class SignedNode:
+
+ """
+ Represents a summand in an expression X1-X2+X3+X4-X5...
+ Used for creating +/- Fliparoos
+ """
+
+ node: CodeOpNode
+ sign: int
+
+ def __init__(self, node: CodeOpNode):
+ self.node = node
+ self.sign = 1
+
+ def __repr__(self):
+ s = {1: "+", -1: "-"}[self.sign]
+ return f"{s}{self.node.__repr__()}"
+
+
+class SignedSubGraph:
+ """Subgraph of an EFDFormula graph with signed nodes"""
+
+ def __init__(self, nodes: List[SignedNode], graph: EFDFormulaGraph):
+ self.nodes = nodes
+ self.supergraph = graph
+
+ def add_node(self, node: SignedNode):
+ self.nodes.append(node)
+
+ def remove_node(self, node: SignedNode):
+ self.nodes.remove(node)
+
+ def change_signs(self):
+ for node in self.nodes:
+ node.sign *= -1
+
+ def __getitem__(self, i):
+ return self.nodes[i]
+
+ @property
+ def components(self):
+ return len(self.nodes)
+
+ def deepcopy(self):
+ sgraph = self.supergraph.deepcopy()
+ return SignedSubGraph(
+ [mirror_node(n, self.supergraph, sgraph) for n in self.nodes], sgraph
+ )
+
+
+def mirror_node(node, graph, graph_copy):
+ """Finds the corresponding node in a copy of the graph"""
+ if isinstance(node, SignedNode):
+ ns = SignedNode(graph_copy.nodes[graph.node_index(node.node)])
+ ns.sign = node.sign
+ return ns
+ if isinstance(node, Node):
+ return graph_copy.nodes[graph.node_index(node)]
+ raise NotImplementedError
+
+
+class DummyNode(Node):
+ def __repr__(self):
+ return "Dummy node"
+
+ def label(self):
+ pass
+
+ def result(self):
+ pass
+
+
+def generate_fliparood_formulas(
+ formula: EFDFormula, rename: bool = True
+) -> Iterator[EFDFormula]:
+ graph = EFDFormulaGraph(formula, rename)
+ fliparoos = find_fliparoos(graph)
+ for fliparoo in fliparoos:
+ for flip_graph in generate_fliparood_graphs(fliparoo):
+ yield flip_graph.to_EFDFormula()
+
+
+def generate_fliparood_graphs(fliparoo: Fliparoo) -> Iterator[EFDFormulaGraph]:
+ fliparoo = fliparoo.deepcopy()
+ last_str = fliparoo.last.result
+ disconnect_fliparoo_outputs(fliparoo)
+
+ signed_subgraph = extract_fliparoo_signed_inputs(fliparoo)
+
+ # Starting with a single list of unconnected signed nodes
+ signed_subgraphs = [signed_subgraph]
+ for _ in range(signed_subgraph.components - 1):
+ subgraphs_updated = []
+ for signed_subgraph in signed_subgraphs:
+ extended_subgraphs = combine_all_pairs_signed_nodes(
+ signed_subgraph, fliparoo
+ )
+ subgraphs_updated.extend(extended_subgraphs)
+ signed_subgraphs = subgraphs_updated
+
+ for signed_subgraph in signed_subgraphs:
+ graph = signed_subgraph.supergraph
+ assert signed_subgraph.components == 1
+ final_signed_node = signed_subgraph.nodes[0]
+ if final_signed_node.sign != 1:
+ continue
+ final_node: CodeOpNode = final_signed_node.node
+
+ opstr = f"{last_str} = {final_node.op.left}{final_node.optype.op_str}{final_node.op.right}"
+ final_node.op = CodeOp(parse(opstr))
+ reconnect_fliparoo_outputs(graph, final_node)
+ graph.update()
+ yield graph
+
+
+def extract_fliparoo_signed_inputs(fliparoo: Fliparoo) -> SignedSubGraph:
+ graph = fliparoo.graph
+ signed_inputs = SignedSubGraph([], graph)
+ for node in fliparoo:
+ prev = fliparoo.previous(node)
+ left, right = map(SignedNode, node.incoming_nodes)
+ if right.node != prev:
+ right.sign = -1 if node.is_sub else 1
+ signed_inputs.add_node(right)
+ if left.node != prev:
+ if node.is_sub and right.node == prev:
+ signed_inputs.change_signs()
+ signed_inputs.add_node(left)
+ if prev:
+ graph.remove_node(prev)
+ graph.remove_node(fliparoo.last)
+ return signed_inputs
+
+
+def disconnect_fliparoo_outputs(fliparoo: Fliparoo):
+ # remember positions of the last node with a DummyNode placeholder
+ dummy = DummyNode()
+ fliparoo.graph.add_node(dummy)
+ fliparoo.last.reconnect_outgoing_nodes(dummy)
+
+
+def reconnect_fliparoo_outputs(graph: EFDFormulaGraph, last_node: Node):
+ dummy = next(filter(lambda x: isinstance(x, DummyNode), graph.nodes))
+ dummy.reconnect_outgoing_nodes(last_node)
+ graph.remove_node(dummy)
+ assert not list(filter(lambda x: isinstance(x, DummyNode), graph.nodes))
+
+
+def combine_all_pairs_signed_nodes(
+ signed_subgraph: SignedSubGraph, fliparoo: Fliparoo
+) -> List[SignedSubGraph]:
+ signed_subgraphs = []
+ n_components = signed_subgraph.components
+ for i in range(n_components):
+ for j in range(i + 1, n_components):
+ csigned_subgraph = signed_subgraph.deepcopy()
+ v, w = csigned_subgraph[i], csigned_subgraph[j]
+ combine_signed_nodes(csigned_subgraph, v, w, fliparoo)
+ signed_subgraphs.append(csigned_subgraph)
+ return signed_subgraphs
+
+
+def combine_signed_nodes(
+ subgraph: SignedSubGraph,
+ left_signed_node: SignedNode,
+ right_signed_node: SignedNode,
+ fliparoo: Fliparoo,
+):
+ left_node, right_node = left_signed_node.node, right_signed_node.node
+ sign = 1
+ operator = OpType.Mult
+ if isinstance(fliparoo, AddSubFliparoo):
+ s0, s1 = left_signed_node.sign, right_signed_node.sign
+ if s0 == 1:
+ operator = OpType.Add if s1 == 1 else OpType.Sub
+
+ if s0 == -1 and s1 == 1:
+ operator = OpType.Sub
+ left_node, right_node = right_node, left_node
+
+ # propagate the sign
+ if s0 == -1 and s1 == -1:
+ operator = OpType.Add
+ sign = -1
+
+ new_node = CodeOpNode.from_str(
+ f"Fliparoo{id(left_node)}_{id(operator)}_{id(sign)}_{id(right_node)}", left_node.result, operator, right_node.result
+ )
+ new_node.incoming_nodes = [left_node, right_node]
+ left_node.outgoing_nodes.append(new_node)
+ right_node.outgoing_nodes.append(new_node)
+ subgraph.supergraph.add_node(new_node)
+ new_node = SignedNode(new_node)
+ new_node.sign = sign
+ subgraph.remove_node(left_signed_node)
+ subgraph.remove_node(right_signed_node)
+ subgraph.add_node(new_node)
+
+
+def recursive_fliparoo(formula, depth=2):
+ all_fliparoos = {0: [formula]}
+ counter = 0
+ while depth > counter:
+ prev_level = all_fliparoos[counter]
+ fliparoo_level = []
+ for flipparood_formula in prev_level:
+ rename = not counter # rename ivs before the first fliparoo
+ for newly_fliparood in generate_fliparood_formulas(
+ flipparood_formula, rename
+ ):
+ fliparoo_level.append(newly_fliparood)
+ counter += 1
+ all_fliparoos[counter] = fliparoo_level
+
+ return sum(all_fliparoos.values(), [])
diff --git a/pyecsca/ec/formula/graph.py b/pyecsca/ec/formula/graph.py
new file mode 100644
index 0000000..1473858
--- /dev/null
+++ b/pyecsca/ec/formula/graph.py
@@ -0,0 +1,436 @@
+from .efd import (
+ EFDFormula,
+ DoublingEFDFormula,
+ AdditionEFDFormula,
+ LadderEFDFormula,
+ DifferentialAdditionEFDFormula,
+)
+from ..op import CodeOp, OpType
+import matplotlib.pyplot as plt
+import networkx as nx
+from ast import parse
+from typing import Dict, List, Tuple, Set, Optional, MutableMapping, Any
+from copy import deepcopy
+from abc import ABC, abstractmethod
+
+
+class Node(ABC):
+ def __init__(self):
+ self.incoming_nodes = []
+ self.outgoing_nodes = []
+ self.output_node = False
+ self.input_node = False
+
+ @property
+ @abstractmethod
+ def label(self) -> str:
+ pass
+
+ @property
+ @abstractmethod
+ def result(self) -> str:
+ pass
+
+ @property
+ def is_sub(self) -> bool:
+ return False
+
+ @property
+ def is_mul(self) -> bool:
+ return False
+
+ @property
+ def is_add(self) -> bool:
+ return False
+
+ @property
+ def is_id(self) -> bool:
+ return False
+
+ @property
+ def is_sqr(self) -> bool:
+ return False
+
+ @property
+ def is_pow(self) -> bool:
+ return False
+
+ @property
+ def is_inv(self) -> bool:
+ return False
+
+ @property
+ def is_div(self) -> bool:
+ return False
+
+ @property
+ def is_neg(self) -> bool:
+ return False
+
+ @abstractmethod
+ def __repr__(self) -> str:
+ pass
+
+ def reconnect_outgoing_nodes(self, destination):
+ destination.output_node = self.output_node
+ for out in self.outgoing_nodes:
+ out.incoming_nodes = [
+ n if n != self else destination for n in out.incoming_nodes
+ ]
+ destination.outgoing_nodes.append(out)
+
+
+class ConstantNode(Node):
+ color = "#b41f44"
+
+ def __init__(self, i: int):
+ super().__init__()
+ self.value = i
+
+ @property
+ def label(self) -> str:
+ return str(self.value)
+
+ @property
+ def result(self) -> str:
+ return str(self.value)
+
+ def __repr__(self) -> str:
+ return f"Node({self.value})"
+
+
+class CodeOpNode(Node):
+ color = "#1f78b4"
+
+ def __init__(self, op: CodeOp):
+ super().__init__()
+ self.op = op
+ assert self.op.operator in [
+ OpType.Sub,
+ OpType.Add,
+ OpType.Id,
+ OpType.Sqr,
+ OpType.Mult,
+ OpType.Div,
+ OpType.Pow,
+ OpType.Inv,
+ OpType.Neg,
+ ], self.op.operator
+
+ @classmethod
+ def from_str(cls, result: str, left, operator, right):
+ opstr = f"{result} = {left if left is not None else ''}{operator.op_str}{right if right is not None else ''}"
+ return cls(CodeOp(parse(opstr.replace("^", "**"))))
+
+ @property
+ def label(self) -> str:
+ return f"{self.op.result}:{self.op.operator.op_str}"
+
+ @property
+ def result(self) -> str:
+ return str(self.op.result)
+
+ @property
+ def optype(self) -> OpType:
+ return self.op.operator
+
+ @property
+ def is_sub(self) -> bool:
+ return self.optype == OpType.Sub
+
+ @property
+ def is_mul(self) -> bool:
+ return self.optype == OpType.Mult
+
+ @property
+ def is_add(self) -> bool:
+ return self.optype == OpType.Add
+
+ @property
+ def is_id(self) -> bool:
+ return self.optype == OpType.Id
+
+ @property
+ def is_sqr(self) -> bool:
+ return self.optype == OpType.Sqr
+
+ @property
+ def is_pow(self) -> bool:
+ return self.optype == OpType.Pow
+
+ @property
+ def is_inv(self) -> bool:
+ return self.optype == OpType.Inv
+
+ @property
+ def is_div(self) -> bool:
+ return self.optype == OpType.Div
+
+ @property
+ def is_neg(self) -> bool:
+ return self.optype == OpType.Neg
+
+ def __repr__(self) -> str:
+ return f"Node({self.op})"
+
+
+class InputNode(Node):
+ color = "#b41f44"
+
+ def __init__(self, input: str):
+ super().__init__()
+ self.input = input
+ self.input_node = True
+
+ @property
+ def label(self) -> str:
+ return self.input
+
+ @property
+ def result(self) -> str:
+ return self.input
+
+ def __repr__(self) -> str:
+ return f"Node({self.input})"
+
+
+def formula_input_variables(formula: EFDFormula) -> List[str]:
+ return (
+ list(formula.inputs)
+ + formula.parameters
+ + formula.coordinate_model.curve_model.parameter_names
+ )
+
+
+# temporary solution
+class ModifiedEFDFormula(EFDFormula):
+ def __eq__(self, other):
+ if not isinstance(other, ModifiedEFDFormula):
+ return False
+ return (
+ self.name == other.name and self.coordinate_model == other.coordinate_model and self.code == other.code
+ )
+
+
+class ModifiedDoublingEFDFormula(DoublingEFDFormula, ModifiedEFDFormula):
+ pass
+
+
+class ModifiedAdditionEFDFormula(AdditionEFDFormula, ModifiedEFDFormula):
+ pass
+
+
+class ModifiedDifferentialAdditionEFDFormula(
+ DifferentialAdditionEFDFormula, ModifiedEFDFormula
+):
+ pass
+
+
+class ModifiedLadderEFDFormula(LadderEFDFormula, ModifiedEFDFormula):
+ pass
+
+
+class EFDFormulaGraph:
+ nodes: List[Node]
+ input_nodes: MutableMapping[str, InputNode]
+ output_names: Set[str]
+ roots: List[Node]
+ coordinate_model: Any
+
+ def __init__(self, formula: EFDFormula, rename=True):
+ self._formula = formula # TODO remove, its here only for to_EFDFormula
+ self.coordinate_model = formula.coordinate_model
+ self.output_names = formula.outputs
+ self.input_nodes = {v: InputNode(v) for v in formula_input_variables(formula)}
+ self.roots = list(self.input_nodes.values())
+ self.nodes = self.roots.copy()
+ discovered_nodes: Dict[str, Node] = self.input_nodes.copy() # type: ignore
+ constants: Dict[int, Node] = {}
+ for op in formula.code:
+ code_node = CodeOpNode(op)
+ for side in (op.left, op.right):
+ if side is None:
+ continue
+ if isinstance(side, int):
+ if side in constants:
+ parent_node = constants[side]
+ else:
+ parent_node = ConstantNode(side)
+ self.nodes.append(parent_node)
+ self.roots.append(parent_node)
+ else:
+ parent_node = discovered_nodes[side]
+ parent_node.outgoing_nodes.append(code_node)
+ code_node.incoming_nodes.append(parent_node)
+ self.nodes.append(code_node)
+ discovered_nodes[op.result] = code_node
+ # flag output nodes
+ for output_name in self.output_names:
+ discovered_nodes[output_name].output_node = True
+
+ # go through the nodes and make sure that every node is root or has parents
+ for node in self.nodes:
+ if not node.incoming_nodes and node not in self.roots:
+ self.roots.append(node)
+ if rename:
+ self.reindex()
+
+ def node_index(self, node: Node) -> int:
+ return self.nodes.index(node)
+
+ def deepcopy(self):
+ return deepcopy(self)
+
+ def to_EFDFormula(self) -> ModifiedEFDFormula:
+ # TODO rewrite
+ new_graph = deepcopy(self)
+ new_formula = new_graph._formula
+ new_formula.code = list(
+ map(
+ lambda x: x.op, # type: ignore
+ filter(lambda n: n not in new_graph.roots, new_graph.nodes),
+ )
+ )
+ casting = {
+ AdditionEFDFormula: ModifiedAdditionEFDFormula,
+ DoublingEFDFormula: ModifiedDoublingEFDFormula,
+ DifferentialAdditionEFDFormula: ModifiedDifferentialAdditionEFDFormula,
+ LadderEFDFormula: ModifiedLadderEFDFormula,
+ }
+ if new_formula.__class__ not in set(casting.values()):
+ new_formula.__class__ = casting[new_formula.__class__]
+ return new_formula # type: ignore
+
+ def networkx_graph(self) -> nx.DiGraph:
+ graph = nx.DiGraph()
+ vertices = list(range(len(self.nodes)))
+ graph.add_nodes_from(vertices)
+ stack = self.roots.copy()
+ while stack:
+ node = stack.pop()
+ for out in node.outgoing_nodes:
+ stack.append(out)
+ graph.add_edge(self.node_index(node), self.node_index(out))
+ return graph
+
+ def levels(self) -> List[List[Node]]:
+ stack = self.roots.copy()
+ levels = [(n, 0) for n in stack]
+ level_counter = 1
+ while stack:
+ tmp_stack = []
+ while stack:
+ node = stack.pop()
+ levels.append((node, level_counter))
+ for out in node.outgoing_nodes:
+ tmp_stack.append(out)
+ stack = tmp_stack
+ level_counter += 1
+ # separate into lists
+
+ level_lists: List[List[Node]] = [[] for _ in range(level_counter)]
+ discovered = []
+ for node, l in reversed(levels):
+ if node not in discovered:
+ level_lists[l].append(node)
+ discovered.append(node)
+ return level_lists
+
+ def output_nodes(self) -> List[Node]:
+ return list(filter(lambda x: x.output_node, self.nodes))
+
+ def planar_positions(self) -> Dict[int, Tuple[float, float]]:
+ positions = {}
+ for i, level in enumerate(self.levels()):
+ for j, node in enumerate(level):
+ positions[self.node_index(node)] = (
+ 0.1 * float(i**2) + 3 * float(j),
+ -6 * float(i),
+ )
+ return positions
+
+ def draw(self, filename: Optional[str] = None, figsize: Tuple[int, int] = (12, 12)):
+ gnx = self.networkx_graph()
+ pos = nx.rescale_layout_dict(self.planar_positions())
+ plt.figure(figsize=figsize)
+ colors = [self.nodes[n].color for n in gnx.nodes]
+ labels = {n: self.nodes[n].label for n in gnx.nodes}
+ nx.draw(gnx, pos, node_color=colors, node_size=2000, labels=labels)
+ if filename:
+ plt.savefig(filename)
+ plt.close()
+ else:
+ plt.show()
+
+ def find_all_paths(self) -> List[List[Node]]:
+ gnx = self.networkx_graph()
+ index_paths = []
+ for u in self.roots:
+ iu = self.node_index(u)
+ for v in self.output_nodes():
+ iv = self.node_index(v)
+ index_paths.extend(nx.all_simple_paths(gnx, iu, iv))
+ paths = []
+ for p in index_paths:
+ paths.append([self.nodes[v] for v in p])
+ return paths
+
+ def reorder(self):
+ self.nodes = sum(self.levels(), [])
+
+ def remove_node(self, node):
+ self.nodes.remove(node)
+ if node in self.roots:
+ self.roots.remove(node)
+ for in_node in node.incoming_nodes:
+ in_node.outgoing_nodes = list(
+ filter(lambda x: x != node, in_node.outgoing_nodes)
+ )
+ for out_node in node.outgoing_nodes:
+ out_node.incoming_nodes = list(
+ filter(lambda x: x != node, out_node.incoming_nodes)
+ )
+
+ def add_node(self, node):
+ if isinstance(node, ConstantNode):
+ self.roots.append(node)
+ self.nodes.append(node)
+
+ def reindex(self):
+ results: Dict[str, str] = {}
+ counter = 0
+ for node in self.nodes:
+ if not isinstance(node, CodeOpNode):
+ continue
+ op = node.op
+ result, left, operator, right = (
+ op.result,
+ op.left,
+ op.operator.op_str,
+ op.right,
+ )
+ if left in results:
+ left = results[left]
+ if right in results:
+ right = results[right]
+ if not node.output_node:
+ new_result = f"iv{counter}"
+ counter += 1
+ else:
+ new_result = result
+ opstr = f"{new_result} = {left if left is not None else ''}{operator}{right if right is not None else ''}"
+ results[result] = new_result
+ node.op = CodeOp(parse(opstr.replace("^", "**")))
+
+ def update(self):
+ self.reorder()
+ self.reindex()
+
+ def print(self):
+ for node in self.nodes:
+ print(node)
+
+
+def rename_ivs(formula: EFDFormula):
+ graph = EFDFormulaGraph(formula)
+ return graph.to_EFDFormula()
diff --git a/pyecsca/ec/formula/metrics.py b/pyecsca/ec/formula/metrics.py
new file mode 100644
index 0000000..a0e42eb
--- /dev/null
+++ b/pyecsca/ec/formula/metrics.py
@@ -0,0 +1,100 @@
+from public import public
+from ...sca.re.zvp import unroll_formula
+from .base import Formula
+import warnings
+from typing import Dict
+from operator import itemgetter, attrgetter
+from ..curve import EllipticCurve
+from ..context import DefaultContext, local
+
+
+@public
+def formula_ivs(formula: Formula):
+ one_unroll = unroll_formula(formula)
+ one_results = {}
+ for name, value in one_unroll:
+ if name in formula.outputs:
+ one_results[name] = value
+ one_polys = set(map(itemgetter(1), one_unroll))
+ return one_polys, set(one_results.values())
+
+
+@public
+def ivs_norm(one: Formula):
+ return formula_ivs(one)[0]
+
+
+@public
+def formula_similarity(one: Formula, other: Formula) -> Dict[str, float]:
+ if one.coordinate_model != other.coordinate_model:
+ warnings.warn("Mismatched coordinate model.")
+
+ one_polys, one_result_polys = formula_ivs(one)
+ other_polys, other_result_polys = formula_ivs(other)
+ return {
+ "output": len(one_result_polys.intersection(other_result_polys))
+ / max(len(one_result_polys), len(other_result_polys)),
+ "ivs": len(one_polys.intersection(other_polys))
+ / max(len(one_polys), len(other_polys)),
+ }
+
+
+@public
+def formula_similarity_abs(one: Formula, other: Formula) -> Dict[str, float]:
+ if one.coordinate_model != other.coordinate_model:
+ warnings.warn("Mismatched coordinate model.")
+
+ one_polys, one_result_polys = formula_ivs(one)
+ other_polys, other_result_polys = formula_ivs(other)
+
+ one_polys = set([f if f.LC() > 0 else -f for f in one_polys])
+ other_polys = set([f if f.LC() > 0 else -f for f in other_polys])
+
+ one_result_polys = set([f if f.LC() > 0 else -f for f in one_result_polys])
+ other_result_polys = set([f if f.LC() > 0 else -f for f in other_result_polys])
+ return {
+ "output": len(one_result_polys.intersection(other_result_polys))
+ / max(len(one_result_polys), len(other_result_polys)),
+ "ivs": len(one_polys.intersection(other_polys))
+ / max(len(one_polys), len(other_polys)),
+ }
+
+
+@public
+def formula_similarity_fuzz(
+ one: Formula, other: Formula, curve: EllipticCurve, samples: int = 1000
+) -> Dict[str, float]:
+ if one.coordinate_model != other.coordinate_model:
+ raise ValueError("Mismatched coordinate model.")
+
+ output_matches = 0.0
+ iv_matches = 0.0
+ for _ in range(samples):
+ Paff = curve.affine_random()
+ Qaff = curve.affine_random()
+ Raff = curve.affine_add(Paff, Qaff)
+ P = Paff.to_model(one.coordinate_model, curve)
+ Q = Qaff.to_model(one.coordinate_model, curve)
+ R = Raff.to_model(one.coordinate_model, curve)
+ inputs = (P, Q, R)[: one.num_inputs]
+ with local(DefaultContext()) as ctx:
+ res_one = one(curve.prime, *inputs, **curve.parameters)
+ action_one = ctx.actions.get_by_index([0])
+ ivs_one = set(
+ map(attrgetter("value"), sum(action_one[0].intermediates.values(), []))
+ )
+ with local(DefaultContext()) as ctx:
+ res_other = other(curve.prime, *inputs, **curve.parameters)
+ action_other = ctx.actions.get_by_index([0])
+ ivs_other = set(
+ map(attrgetter("value"), sum(action_other[0].intermediates.values(), []))
+ )
+ iv_matches += len(ivs_one.intersection(ivs_other)) / max(
+ len(ivs_one), len(ivs_other)
+ )
+ one_coords = set(res_one)
+ other_coords = set(res_other)
+ output_matches += len(one_coords.intersection(other_coords)) / max(
+ len(one_coords), len(other_coords)
+ )
+ return {"output": output_matches / samples, "ivs": iv_matches / samples}
diff --git a/pyecsca/ec/formula/partitions.py b/pyecsca/ec/formula/partitions.py
new file mode 100644
index 0000000..9ea108c
--- /dev/null
+++ b/pyecsca/ec/formula/partitions.py
@@ -0,0 +1,378 @@
+from typing import List, Any, Generator
+from ast import parse
+from ..op import OpType, CodeOp
+from .graph import (
+ EFDFormulaGraph,
+ CodeOpNode,
+ ConstantNode,
+ Node,
+)
+from .fliparoo import find_fliparoos, AddFliparoo, MulFliparoo
+from copy import deepcopy
+from .efd import EFDFormula
+
+
+def reduce_all_adds(formula: EFDFormula, rename=True) -> EFDFormula:
+ graph = EFDFormulaGraph(formula, rename=rename)
+ add_fliparoos = find_single_input_add_fliparoos(graph)
+ for add_fliparoo in add_fliparoos:
+ reduce_add_fliparoo(add_fliparoo, copy=False)
+ reduce_all_XplusX(graph)
+ mul_fliparoos = find_constant_mul_fliparoos(graph)
+ for mul_fliparoo in mul_fliparoos:
+ reduce_mul_fliparoo(mul_fliparoo, copy=False)
+ return graph.to_EFDFormula()
+
+
+def expand_all_muls(formula: EFDFormula, rename=True) -> EFDFormula:
+ graph = EFDFormulaGraph(formula, rename)
+ enodes = find_expansion_nodes(graph)
+ for enode in enodes:
+ expand_mul(graph, enode, copy=False)
+ return graph.to_EFDFormula()
+
+
+def expand_all_nopower2_muls(formula: EFDFormula, rename=True) -> EFDFormula:
+ graph = EFDFormulaGraph(formula, rename)
+ enodes = find_expansion_nodes(graph, nopower2=True)
+ for enode in enodes:
+ expand_mul(graph, enode, copy=False)
+ return graph.to_EFDFormula()
+
+
+def find_single_input_add_fliparoos(graph: EFDFormulaGraph) -> List[AddFliparoo]:
+ fliparoos = find_fliparoos(graph, AddFliparoo)
+ single_input_fliparoos = []
+ for fliparoo in fliparoos:
+ found = False
+ for i in range(len(fliparoo), 1, -1):
+ subfliparoo = fliparoo.slice(0, i)
+ if len(set(subfliparoo.input_nodes())) == 1:
+ found = True
+ break
+ if found:
+ s = subfliparoo.slice(0, i)
+ single_input_fliparoos.append(s)
+ return single_input_fliparoos
+
+
+def find_constant_mul_fliparoos(graph: EFDFormulaGraph) -> List[MulFliparoo]:
+ fliparoos = find_fliparoos(graph, MulFliparoo)
+ constant_mul_fliparoo = []
+ for fliparoo in fliparoos:
+ found = False
+ for i in range(len(fliparoo), 1, -1):
+ subfliparoo = fliparoo.slice(0, i)
+ nonconstant_inputs = list(
+ filter(
+ lambda x: not isinstance(x, ConstantNode), subfliparoo.input_nodes()
+ )
+ )
+ if len(nonconstant_inputs) != 1:
+ continue
+ inode = nonconstant_inputs[0]
+ if inode not in fliparoo.first.incoming_nodes:
+ continue
+ if not sum(
+ 1
+ for node in fliparoo.first.incoming_nodes
+ if isinstance(node, ConstantNode)
+ ):
+ continue
+ found = True
+ break
+ if found:
+ s = subfliparoo.slice(0, i)
+ constant_mul_fliparoo.append(s)
+ return constant_mul_fliparoo
+
+
+def find_expansion_nodes(graph: EFDFormulaGraph, nopower2=False) -> List[Node]:
+ expansion_nodes: List[Node] = []
+ for node in graph.nodes:
+ if not isinstance(node, CodeOpNode) or not node.is_mul:
+ continue
+ for par in node.incoming_nodes:
+ if isinstance(par, ConstantNode):
+ if nopower2 and is_power_of_2(par.value):
+ continue
+ expansion_nodes.append(node)
+ break
+ return expansion_nodes
+
+
+def is_power_of_2(n: int) -> bool:
+ while n > 1:
+ if n & 1 == 1:
+ return False
+ n >>= 1
+ return True
+
+
+def reduce_all_XplusX(graph: EFDFormulaGraph):
+ adds = find_all_XplusX(graph)
+ for node in adds:
+ reduce_XplusX(graph, node)
+ graph.update()
+
+
+def find_all_XplusX(graph) -> List[CodeOpNode]:
+ adds = []
+ for node in graph.nodes:
+ if not isinstance(node, CodeOpNode) or not node.is_add:
+ continue
+ if node.incoming_nodes[0] == node.incoming_nodes[1]:
+ adds.append(node)
+ return adds
+
+
+def reduce_XplusX(graph: EFDFormulaGraph, node: CodeOpNode):
+ inode = node.incoming_nodes[0]
+ const_node = ConstantNode(2)
+ node.incoming_nodes[1] = const_node
+ const_node.outgoing_nodes = [node]
+ graph.add_node(const_node)
+ inode.outgoing_nodes = list(filter(lambda x: x != node, inode.outgoing_nodes))
+ inode.outgoing_nodes.append(node)
+ opstr = f"{node.result} = {inode.result}{OpType.Mult.op_str}{const_node.value}"
+ node.op = CodeOp(parse(opstr))
+
+
+def reduce_mul_fliparoo(fliparoo: MulFliparoo, copy=True):
+ if copy:
+ fliparoo = fliparoo.deepcopy()
+
+ first, last = fliparoo.first, fliparoo.last
+ inode = next(
+ filter(lambda x: not isinstance(x, ConstantNode), first.incoming_nodes)
+ )
+ const_nodes: List[ConstantNode] = [node for node in fliparoo.input_nodes() if isinstance(node, ConstantNode)]
+ sum_const_node = ConstantNode(sum(v.value for v in const_nodes))
+ fliparoo.graph.add_node(sum_const_node)
+
+ inode.outgoing_nodes = [n if n != first else last for n in inode.outgoing_nodes]
+ last.incoming_nodes = [inode, sum_const_node]
+ sum_const_node.outgoing_nodes = [last]
+
+ opstr = f"{last.result} = {inode.result}{OpType.Mult.op_str}{sum_const_node.value}"
+ last.op = CodeOp(parse(opstr))
+
+ for node in fliparoo:
+ if node == last:
+ continue
+ fliparoo.graph.remove_node(node)
+
+ for node in const_nodes:
+ if not node.outgoing_nodes:
+ fliparoo.graph.remove_node(node)
+
+ fliparoo.graph.update()
+
+ return fliparoo.graph
+
+
+def reduce_add_fliparoo(fliparoo: AddFliparoo, copy=True):
+ if copy:
+ fliparoo = fliparoo.deepcopy()
+ first, last = fliparoo.first, fliparoo.last
+ par = first.incoming_nodes[0]
+ const_node = ConstantNode(len(fliparoo) + 1)
+ fliparoo.graph.add_node(const_node)
+ mul_node = CodeOpNode.from_str(
+ last.result, const_node.result, OpType.Mult, par.result
+ )
+ fliparoo.graph.add_node(mul_node)
+ mul_node.incoming_nodes = [const_node, par]
+ par.outgoing_nodes.append(mul_node)
+ const_node.outgoing_nodes.append(mul_node)
+ mul_node.output_node = last.output_node
+ last.reconnect_outgoing_nodes(mul_node)
+ for node in fliparoo:
+ fliparoo.graph.remove_node(node)
+
+ fliparoo.graph.update()
+
+ return fliparoo.graph
+
+
+def expand_mul(graph: EFDFormulaGraph, node: Node, copy=True) -> EFDFormulaGraph:
+ if copy:
+ i = graph.node_index(node)
+ graph = deepcopy(graph)
+ node = graph.nodes[i]
+
+ const_par = next(filter(lambda x: isinstance(x, ConstantNode), node.incoming_nodes))
+ par = next(filter(lambda x: not isinstance(x, ConstantNode), node.incoming_nodes))
+ initial_node = CodeOpNode.from_str(node.result, par.result, OpType.Add, par.result)
+ graph.add_node(initial_node)
+ initial_node.incoming_nodes = [par, par]
+ par.outgoing_nodes.extend([initial_node, initial_node])
+ prev_node = initial_node
+ for _ in range(const_par.value - 2):
+ anode = CodeOpNode.from_str(
+ node.result, prev_node.result, OpType.Add, par.result
+ )
+ anode.incoming_nodes = [prev_node, par]
+ par.outgoing_nodes.append(anode)
+ graph.add_node(anode)
+ prev_node.outgoing_nodes = [anode]
+ prev_node = anode
+
+ prev_node.output_node = node.output_node
+ node.reconnect_outgoing_nodes(prev_node)
+ graph.remove_node(node)
+ graph.remove_node(const_par)
+ graph.update()
+
+ return graph
+
+
+class Partition:
+ value: int
+ parts: List["Partition"]
+
+ def __init__(self, n: int):
+ self.value = n
+ self.parts = []
+
+ @property
+ def is_final(self):
+ return not self.parts
+
+ def __repr__(self):
+ if self.is_final:
+ return f"({self.value})"
+ l, r = self.parts
+ return f"({l.__repr__()},{r.__repr__()})"
+
+ def __add__(self, other):
+ a = Partition(self.value + other.value)
+ a.parts = [self, other]
+ return a
+
+ def __eq__(self, other):
+ if self.value != other.value:
+ return False
+ if self.is_final or other.is_final:
+ return self.is_final == other.is_final
+ l, r = self.parts
+ lo, ro = other.parts
+ return (l == lo and r == ro) or (l == ro and r == lo)
+
+
+def compute_partitions(n: int) -> List[Partition]:
+ partitions = [Partition(n)]
+ for d in range(1, n // 2 + 1):
+ n_d = n - d
+ for partition_dp in compute_partitions(d):
+ for partition_n_dp in compute_partitions(n_d):
+ partitions.append(partition_dp + partition_n_dp)
+ # remove duplicates
+ result = []
+ for p in partitions:
+ if p not in result:
+ result.append(p)
+ return result
+
+
+def generate_partitioned_formulas(formula: EFDFormula, rename=True):
+ graph = EFDFormulaGraph(formula, rename)
+ enodes = find_expansion_nodes(graph)
+ for enode in enodes:
+ for part_graph in generate_all_node_partitions(graph, enode):
+ yield part_graph.to_EFDFormula()
+
+
+def generate_all_node_partitions(
+ original_graph: EFDFormulaGraph, node: Node
+) -> Generator[EFDFormulaGraph, Any, None]:
+ const_par = next(filter(lambda x: isinstance(x, ConstantNode), node.incoming_nodes))
+ const_par_value = const_par.value
+
+ par = next(filter(lambda x: not isinstance(x, ConstantNode), node.incoming_nodes))
+ i, ic, ip = (
+ original_graph.node_index(node),
+ original_graph.node_index(const_par),
+ original_graph.node_index(par),
+ )
+
+ for partition in compute_partitions(const_par_value):
+ if partition.is_final:
+ continue
+
+ # copy
+ graph = deepcopy(original_graph)
+ node, const_par, par = graph.nodes[i], graph.nodes[ic], graph.nodes[ip]
+ graph.remove_node(const_par)
+ lresult, rresult = f"{node.result}L", f"{node.result}R"
+ empty_left_node = CodeOpNode.from_str(lresult, "PART", OpType.Add, "PART")
+ empty_right_node = CodeOpNode.from_str(rresult, "PART", OpType.Add, "PART")
+ addition_node = CodeOpNode.from_str(node.result, lresult, OpType.Add, rresult)
+ graph.add_node(empty_left_node)
+ graph.add_node(empty_right_node)
+ graph.add_node(addition_node)
+ addition_node.outgoing_nodes = node.outgoing_nodes
+ addition_node.output_node = node.output_node
+ addition_node.incoming_nodes = [empty_left_node, empty_right_node]
+ empty_left_node.outgoing_nodes = [addition_node]
+ empty_right_node.outgoing_nodes = [addition_node]
+
+ left, right = partition.parts
+ partition_node(graph, empty_left_node, left, par)
+ partition_node(graph, empty_right_node, right, par)
+
+ graph.remove_node(node)
+ graph.update()
+ yield graph
+
+
+def partition_node(
+ graph: EFDFormulaGraph, node: CodeOpNode, partition: Partition, source_node: Node
+):
+ if partition.is_final and partition.value == 1:
+ # source node will take the role of node
+ # note: node has precisely one output node, since it was created during partitions
+ assert len(node.outgoing_nodes) == 1
+ child = node.outgoing_nodes[0]
+ source_node.outgoing_nodes.append(child)
+
+ left, right = child.incoming_nodes[0].result, child.incoming_nodes[1].result
+ if child.incoming_nodes[0] == node:
+ left = source_node.result
+ child.incoming_nodes[0] = source_node
+ else:
+ right = source_node.result
+ child.incoming_nodes[1] = source_node
+ opstr = f"{child.result} = {left}{child.optype.op_str}{right}"
+ child.op = CodeOp(parse(opstr))
+ graph.remove_node(node)
+ return
+
+ if partition.is_final:
+ source_node.outgoing_nodes.append(node)
+ const_node = ConstantNode(partition.value)
+ graph.add_node(const_node)
+ opstr = (
+ f"{node.result} = {source_node.result}{OpType.Mult.op_str}{partition.value}"
+ )
+ node.op = CodeOp(parse(opstr))
+ node.incoming_nodes = [source_node, const_node]
+ const_node.outgoing_nodes = [node]
+ return
+
+ lresult, rresult = f"{node.result}L", f"{node.result}R"
+ empty_left_node = CodeOpNode.from_str(lresult, "PART", OpType.Add, "PART")
+ empty_right_node = CodeOpNode.from_str(rresult, "PART", OpType.Add, "PART")
+
+ opstr = f"{node.result} = {lresult}{OpType.Add.op_str}{rresult}"
+ node.op = CodeOp(parse(opstr))
+ graph.add_node(empty_left_node)
+ graph.add_node(empty_right_node)
+
+ node.incoming_nodes = [empty_left_node, empty_right_node]
+ empty_left_node.outgoing_nodes = [node]
+ empty_right_node.outgoing_nodes = [node]
+
+ left, right = partition.parts
+ partition_node(graph, empty_left_node, left, source_node)
+ partition_node(graph, empty_right_node, right, source_node)
diff --git a/pyecsca/ec/formula/switch_sign.py b/pyecsca/ec/formula/switch_sign.py
new file mode 100644
index 0000000..1acef5b
--- /dev/null
+++ b/pyecsca/ec/formula/switch_sign.py
@@ -0,0 +1,131 @@
+from typing import Dict, Iterator, List, Any
+from ast import parse
+from ..op import OpType, CodeOp
+from .graph import EFDFormulaGraph, ConstantNode, Node, CodeOpNode
+from itertools import chain, combinations
+from .efd import EFDFormula
+from ..point import Point
+from ..mod import Mod
+
+
+def generate_switched_formulas(
+ formula: EFDFormula, rename=True
+) -> Iterator[EFDFormula]:
+ graph = EFDFormulaGraph(formula, rename)
+ for node_combination in subnode_lists(graph):
+ try:
+ yield switch_sign(graph, node_combination).to_EFDFormula()
+ except BadSignSwitch:
+ continue
+
+
+def subnode_lists(graph: EFDFormulaGraph):
+ return powerlist(filter(lambda x: x not in graph.roots and x.is_sub, graph.nodes))
+
+
+def switch_sign(graph: EFDFormulaGraph, node_combination) -> EFDFormulaGraph:
+ nodes_i = [graph.node_index(node) for node in node_combination]
+ graph = graph.deepcopy()
+ node_combination = set(graph.nodes[node_i] for node_i in nodes_i)
+ output_signs = {out: 1 for out in graph.output_names}
+
+ queue = []
+ for node in node_combination:
+ change_sides(node)
+ if node.output_node:
+ output_signs[node.result] = -1
+ queue.extend([(out, node.result) for out in node.outgoing_nodes])
+
+ while queue:
+ node, variable = queue.pop()
+ queue = switch_sign_propagate(node, variable, output_signs) + queue
+
+ sign_test(output_signs, graph.coordinate_model)
+ return graph
+
+
+def sign_test(output_signs: Dict[str, int], coordinate_model: Any):
+ scale = coordinate_model.formulas.get("z", None)
+ if scale is None:
+ scale = coordinate_model.formulas.get("scale", None)
+ p = 7
+ out_inds = set(map(lambda x: "".join([o for o in x if o.isdigit()]), output_signs))
+ for ind in out_inds:
+ point_dict = {}
+ for out, sign in output_signs.items():
+ if not out.endswith(ind):
+ continue
+ out_var = out[:out.index(ind)]
+ if not out_var.isalpha():
+ continue
+ point_dict[out_var] = Mod(sign, p)
+ point = Point(coordinate_model, **point_dict)
+ try:
+ apoint = point.to_affine()
+ except NotImplementedError:
+ apoint = scale(p, point)[0]
+ if set(apoint.coords.values()) != set([Mod(1, p)]):
+ raise BadSignSwitch
+
+
+class BadSignSwitch(Exception):
+ pass
+
+
+def switch_sign_propagate(
+ node: CodeOpNode, variable: str, output_signs: Dict[str, int]
+):
+ if node.is_add:
+ if variable == node.incoming_nodes[1].result:
+ node.op = change_operator(node.op, OpType.Sub)
+ return []
+ change_sides(node)
+ node.op = change_operator(node.op, OpType.Sub)
+ return []
+ if node.is_id or node.is_neg:
+ output_signs[node.result] *= -1
+ return [(child, node.result) for child in node.outgoing_nodes]
+
+ if node.is_sqr:
+ return []
+ if node.is_sub:
+ if node.incoming_nodes[0].result == variable:
+ node.op = change_operator(node.op, OpType.Add)
+ if node.output_node:
+ output_signs[node.result] *= -1
+ return [(child, node.result) for child in node.outgoing_nodes]
+ node.op = change_operator(node.op, OpType.Add)
+ return []
+ if node.is_pow:
+ exponent = next(
+ filter(lambda n: isinstance(n, ConstantNode), node.incoming_nodes)
+ )
+ if exponent.value % 2 == 0:
+ return []
+ if node.output_node:
+ output_signs[node.result] *= -1
+ assert node.is_mul or node.is_pow or node.is_inv or node.is_div
+ return [(child, node.result) for child in node.outgoing_nodes]
+
+
+def change_operator(op, new_operator):
+ result, left, right = op.result, op.left, op.right
+ opstr = f"{result} = {left if left is not None else ''}{new_operator.op_str}{right if right is not None else ''}"
+ return CodeOp(parse(opstr.replace("^", "**")))
+
+
+def change_sides(node):
+ op = node.op
+ result, left, operator, right = op.result, op.left, op.operator.op_str, op.right
+ left, right = right, left
+ opstr = f"{result} = {left if left is not None else ''}{operator}{right if right is not None else ''}"
+ node.op = CodeOp(parse(opstr.replace("^", "**")))
+ node.incoming_nodes[1], node.incoming_nodes[0] = (
+ node.incoming_nodes[0],
+ node.incoming_nodes[1],
+ )
+
+
+def powerlist(iterable: Iterator) -> List:
+ s = list(iterable)
+ return list(chain.from_iterable(combinations(s, r) for r in range(len(s) + 1)))
diff --git a/pyecsca/ec/mult/window.py b/pyecsca/ec/mult/window.py
index f85a58a..d025cc1 100644
--- a/pyecsca/ec/mult/window.py
+++ b/pyecsca/ec/mult/window.py
@@ -4,15 +4,22 @@ from typing import Optional, MutableMapping
from public import public
from ..params import DomainParameters
-from .base import ScalarMultiplier, AccumulationOrder, ScalarMultiplicationAction, PrecomputationAction, \
- ProcessingDirection, AccumulatorMultiplier
+from .base import (
+ ScalarMultiplier,
+ AccumulationOrder,
+ ScalarMultiplicationAction,
+ PrecomputationAction,
+ ProcessingDirection,
+ AccumulatorMultiplier,
+)
from ..formula import (
AdditionFormula,
DoublingFormula,
ScalingFormula,
+ NegationFormula,
)
from ..point import Point
-from ..scalar import convert_base, sliding_window_rtl, sliding_window_ltr
+from ..scalar import convert_base, sliding_window_rtl, sliding_window_ltr, booth_window
@public
@@ -34,17 +41,21 @@ class SlidingWindowMultiplier(AccumulatorMultiplier, ScalarMultiplier):
_points: MutableMapping[int, Point]
def __init__(
- self,
- add: AdditionFormula,
- dbl: DoublingFormula,
- width: int,
- scl: Optional[ScalingFormula] = None,
- recoding_direction: ProcessingDirection = ProcessingDirection.LTR,
- accumulation_order: AccumulationOrder = AccumulationOrder.PeqPR,
- short_circuit: bool = True,
+ self,
+ add: AdditionFormula,
+ dbl: DoublingFormula,
+ width: int,
+ scl: Optional[ScalingFormula] = None,
+ recoding_direction: ProcessingDirection = ProcessingDirection.LTR,
+ accumulation_order: AccumulationOrder = AccumulationOrder.PeqPR,
+ short_circuit: bool = True,
):
super().__init__(
- short_circuit=short_circuit, accumulation_order=accumulation_order, add=add, dbl=dbl, scl=scl
+ short_circuit=short_circuit,
+ accumulation_order=accumulation_order,
+ add=add,
+ dbl=dbl,
+ scl=scl,
)
self.width = width
self.recoding_direction = recoding_direction
@@ -55,7 +66,13 @@ class SlidingWindowMultiplier(AccumulatorMultiplier, ScalarMultiplier):
def __eq__(self, other):
if not isinstance(other, SlidingWindowMultiplier):
return False
- return self.formulas == other.formulas and self.short_circuit == other.short_circuit and self.width == other.width and self.recoding_direction == other.recoding_direction and self.accumulation_order == other.accumulation_order
+ return (
+ self.formulas == other.formulas
+ and self.short_circuit == other.short_circuit
+ and self.width == other.width
+ and self.recoding_direction == other.recoding_direction
+ and self.accumulation_order == other.accumulation_order
+ )
def __repr__(self):
return f"{self.__class__.__name__}({', '.join(map(str, self.formulas.values()))}, short_circuit={self.short_circuit}, width={self.width}, recoding_direction={self.recoding_direction.name}, accumulation_order={self.accumulation_order.name})"
@@ -112,16 +129,20 @@ class FixedWindowLTRMultiplier(AccumulatorMultiplier, ScalarMultiplier):
_points: MutableMapping[int, Point]
def __init__(
- self,
- add: AdditionFormula,
- dbl: DoublingFormula,
- m: int,
- scl: Optional[ScalingFormula] = None,
- accumulation_order: AccumulationOrder = AccumulationOrder.PeqPR,
- short_circuit: bool = True,
+ self,
+ add: AdditionFormula,
+ dbl: DoublingFormula,
+ m: int,
+ scl: Optional[ScalingFormula] = None,
+ accumulation_order: AccumulationOrder = AccumulationOrder.PeqPR,
+ short_circuit: bool = True,
):
super().__init__(
- short_circuit=short_circuit, accumulation_order=accumulation_order, add=add, dbl=dbl, scl=scl
+ short_circuit=short_circuit,
+ accumulation_order=accumulation_order,
+ add=add,
+ dbl=dbl,
+ scl=scl,
)
if m < 2:
raise ValueError("Invalid base.")
@@ -134,7 +155,12 @@ class FixedWindowLTRMultiplier(AccumulatorMultiplier, ScalarMultiplier):
def __eq__(self, other):
if not isinstance(other, FixedWindowLTRMultiplier):
return False
- return self.formulas == other.formulas and self.short_circuit == other.short_circuit and self.m == other.m and self.accumulation_order == other.accumulation_order
+ return (
+ self.formulas == other.formulas
+ and self.short_circuit == other.short_circuit
+ and self.m == other.m
+ and self.accumulation_order == other.accumulation_order
+ )
def __repr__(self):
return f"{self.__class__.__name__}({', '.join(map(str, self.formulas.values()))}, short_circuit={self.short_circuit}, m={self.m}, accumulation_order={self.accumulation_order.name})"
@@ -180,3 +206,103 @@ class FixedWindowLTRMultiplier(AccumulatorMultiplier, ScalarMultiplier):
if "scl" in self.formulas:
q = self._scl(q)
return action.exit(q)
+
+
+@public
+class WindowBoothMultiplier(AccumulatorMultiplier, ScalarMultiplier):
+ """
+
+ :param short_circuit: Whether the use of formulas will be guarded by short-circuit on inputs
+ of the point at infinity.
+ :param width: The width of the window.
+ :param accumulation_order: The order of accumulation of points.
+ :param precompute_negation: Whether to precompute the negation of the precomputed points as well.
+ It is computed on the fly otherwise.
+ """
+
+ requires = {AdditionFormula, DoublingFormula, NegationFormula}
+ optionals = {ScalingFormula}
+ _points: MutableMapping[int, Point]
+ _points_neg: MutableMapping[int, Point]
+ precompute_negation: bool = False
+ """Whether to precompute the negation of the precomputed points as well."""
+ width: int
+ """The width of the window."""
+
+ def __init__(
+ self,
+ add: AdditionFormula,
+ dbl: DoublingFormula,
+ neg: NegationFormula,
+ width: int,
+ scl: Optional[ScalingFormula] = None,
+ accumulation_order: AccumulationOrder = AccumulationOrder.PeqPR,
+ precompute_negation: bool = False,
+ short_circuit: bool = True,
+ ):
+ super().__init__(
+ short_circuit=short_circuit,
+ accumulation_order=accumulation_order,
+ add=add,
+ dbl=dbl,
+ neg=neg,
+ scl=scl,
+ )
+ self.width = width
+ self.precompute_negation = precompute_negation
+
+ def __hash__(self):
+ return id(self)
+
+ def __eq__(self, other):
+ if not isinstance(other, WindowBoothMultiplier):
+ return False
+ return (
+ self.formulas == other.formulas
+ and self.short_circuit == other.short_circuit
+ and self.width == other.width
+ and self.precompute_negation == other.precompute_negation
+ and self.accumulation_order == other.accumulation_order
+ )
+
+ def __repr__(self):
+ return f"{self.__class__.__name__}({', '.join(map(str, self.formulas.values()))}, short_circuit={self.short_circuit}, width={self.width}, precompute_negation={self.precompute_negation}, accumulation_order={self.accumulation_order.name})"
+
+ def init(self, params: DomainParameters, point: Point):
+ with PrecomputationAction(params, point):
+ super().init(params, point)
+ double_point = self._dbl(point)
+ self._points = {1: point, 2: double_point}
+ if self.precompute_negation:
+ self._points_neg = {1: self._neg(point), 2: self._neg(double_point)}
+ current_point = double_point
+ for i in range(3, 2 ** (self.width - 1) + 1):
+ current_point = self._add(current_point, point)
+ self._points[i] = current_point
+ if self.precompute_negation:
+ self._points_neg[i] = self._neg(current_point)
+
+ def multiply(self, scalar: int) -> Point:
+ if not self._initialized:
+ raise ValueError("ScalarMultiplier not initialized.")
+ with ScalarMultiplicationAction(self._point, scalar) as action:
+ if scalar == 0:
+ return action.exit(copy(self._params.curve.neutral))
+ scalar_booth = booth_window(
+ scalar, self.width, self._params.order.bit_length()
+ )
+ q = copy(self._params.curve.neutral)
+ for val in scalar_booth:
+ for _ in range(self.width):
+ q = self._dbl(q)
+ if val > 0:
+ q = self._accumulate(q, self._points[val])
+ elif val < 0:
+ if self.precompute_negation:
+ neg = self._points_neg[-val]
+ else:
+ neg = self._neg(self._points[-val])
+ q = self._accumulate(q, neg)
+ if "scl" in self.formulas:
+ q = self._scl(q)
+ return action.exit(q)
diff --git a/pyecsca/ec/point.py b/pyecsca/ec/point.py
index 280d746..bdbba5e 100644
--- a/pyecsca/ec/point.py
+++ b/pyecsca/ec/point.py
@@ -140,7 +140,7 @@ class Point:
if randomized:
lmbd = Mod.random(curve.prime)
for var, value in result.items():
- result[var] = value * lmbd**coordinate_model.homogweights[var]
+ result[var] = value * (lmbd**coordinate_model.homogweights[var])
return action.exit(Point(coordinate_model, **result))
def equals_affine(self, other: "Point") -> bool:
diff --git a/pyecsca/ec/scalar.py b/pyecsca/ec/scalar.py
index 3fadc00..af5a6ab 100644
--- a/pyecsca/ec/scalar.py
+++ b/pyecsca/ec/scalar.py
@@ -1,5 +1,5 @@
"""Provides functions for computing various scalar representations (like NAF, or different bases)."""
-from typing import List
+from typing import List, Tuple, Literal
from itertools import dropwhile
from public import public
@@ -11,7 +11,7 @@ def convert_base(i: int, base: int) -> List[int]:
:param i: The scalar.
:param base: The base.
- :return: The resulting digit list.
+ :return: The resulting digit list (little-endian).
"""
if i == 0:
return [0]
@@ -131,3 +131,62 @@ def naf(k: int) -> List[int]:
:return: The NAF.
"""
return wnaf(k, 2)
+
+
+@public
+def booth(k: int) -> List[int]:
+ """
+ Original Booth binary recoding, from [B51]_.
+
+ :param k: The scalar to recode.
+ :return: The recoded list of digits (0, 1, -1), little-endian.
+ """
+ res = []
+ for i in range(k.bit_length()):
+ a_i = (k >> i) & 1
+ b_i = (k >> (i + 1)) & 1
+ res.append(a_i - b_i)
+ res.insert(0, -(k & 1))
+ return res
+
+
+@public
+def booth_word(digit: int, w: int) -> int:
+ """
+ Modified Booth recoding, from [M61]_ and BoringSSL NIST impl.
+
+ Needs `w+1` bits of scalar in digit, but the one bit is overlapping (window size is `w`).
+
+ :param digit:
+ :param w:
+ :return:
+ """
+ if digit.bit_length() > (w + 1):
+ raise ValueError("Invalid digit, cannot be larger than w + 1 bits.")
+ s = ~((digit >> w) - 1)
+ d = (1 << (w + 1)) - digit - 1
+ d = (d & s) | (digit & ~s)
+ d = (d >> 1) + (d & 1)
+
+ return -d if s else d
+
+
+@public
+def booth_window(k: int, w: int, blen: int) -> List[int]:
+ """
+ Recode a whole scalar using Booth recoding as in BoringSSL.
+
+ :param k: The scalar.
+ :param w: The window size.
+ :param blen: The bit-length of the group.
+ :return: The big-endian recoding
+ """
+ mask = (1 << (w + 1)) - 1
+ res = []
+ for i in range(blen + (w - (blen % w) - 1), -1, -w):
+ if i >= w:
+ d = (k >> (i - w)) & mask
+ else:
+ d = (k << (w - i)) & mask
+ res.append(booth_word(d, w))
+ return res