aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorJ08nY2024-08-26 15:38:58 +0200
committerJ08nY2024-08-26 15:38:58 +0200
commit481f567792d05e6a4e09121a7358252b11706ee8 (patch)
tree654903eeefc005716e1f84ae41a071d532affb34
parentfcb17bf74e4cc6acedbad707a9ddf4e5a59e8a0f (diff)
downloadpyecsca-481f567792d05e6a4e09121a7358252b11706ee8.tar.gz
pyecsca-481f567792d05e6a4e09121a7358252b11706ee8.tar.zst
pyecsca-481f567792d05e6a4e09121a7358252b11706ee8.zip
Add parallel version of formula expansion.
-rw-r--r--pyecsca/ec/formula/expand.py86
-rw-r--r--pyecsca/ec/formula/fliparoo.py4
-rw-r--r--pyecsca/ec/formula/switch_sign.py12
-rw-r--r--test/ec/test_formula.py9
4 files changed, 97 insertions, 14 deletions
diff --git a/pyecsca/ec/formula/expand.py b/pyecsca/ec/formula/expand.py
index eff227d..7e9b5ee 100644
--- a/pyecsca/ec/formula/expand.py
+++ b/pyecsca/ec/formula/expand.py
@@ -1,6 +1,6 @@
"""Provides a formula expansion function."""
-from typing import Set, Callable, Any, List, Iterable
+from typing import Set, Callable, Any, List, Iterable, Optional
from public import public
from operator import attrgetter
from functools import lru_cache
@@ -14,12 +14,14 @@ from pyecsca.ec.formula.partitions import (
expand_all_muls,
expand_all_nopower2_muls,
)
-from pyecsca.ec.formula.switch_sign import generate_switched_formulas
+from pyecsca.ec.formula.switch_sign import generate_switched_formulas, switch_signs
+from pyecsca.misc.utils import TaskExecutor
def reduce_with_similarity(
formulas: List[Formula], norm: Callable[[Formula], Any]
) -> List[Formula]:
+ formulas = reduce_by_eq(formulas)
reduced = list(filter(lambda x: isinstance(x, EFDFormula), formulas))
similarities = list(map(norm, reduced))
for formula in formulas:
@@ -51,10 +53,11 @@ def expand_formula_set(
- Sign switching
- Associativity and Commutativity
- :param formulas:
- :param norm:
- :return:
+ :param formulas: The set of formulas to expand.
+ :param norm: The norm to use while reducing.
+ :return: The expanded set of formulas.
"""
+
@lru_cache(maxsize=1000)
def cached(formula):
return norm(formula)
@@ -62,11 +65,15 @@ def expand_formula_set(
extended = sorted(formulas, key=attrgetter("name"))
extended = reduce_with_similarity(extended, cached)
- fliparood: List[Formula] = reduce_by_eq(sum(map(lambda f: reduce_by_eq(recursive_fliparoo(f)), extended), []))
+ fliparood: List[Formula] = reduce_by_eq(
+ sum(map(lambda f: reduce_by_eq(recursive_fliparoo(f)), extended), [])
+ )
extended.extend(fliparood)
extended = reduce_with_similarity(extended, cached)
- switch_signs: List[Formula] = reduce_by_eq(sum(map(lambda f: reduce_by_eq(generate_switched_formulas(f)), extended), []))
+ switch_signs: List[Formula] = reduce_by_eq(
+ sum(map(lambda f: reduce_by_eq(generate_switched_formulas(f)), extended), [])
+ )
extended.extend(switch_signs)
extended = reduce_with_similarity(extended, cached)
@@ -78,8 +85,71 @@ def expand_formula_set(
extended.extend(mul_expanded)
extended = reduce_with_similarity(extended, cached)
- np2_expanded: List[Formula] = reduce_by_eq(list(map(expand_all_nopower2_muls, extended)))
+ np2_expanded: List[Formula] = reduce_by_eq(
+ list(map(expand_all_nopower2_muls, extended))
+ )
extended.extend(np2_expanded)
extended = reduce_with_similarity(extended, cached)
return set(reduce_by_eq(extended))
+
+
+@public
+def expand_formula_set_parallel(
+ formulas: Set[Formula],
+ norm: Callable[[Formula], Any] = ivs_norm,
+ num_workers: int = 1,
+) -> Set[Formula]:
+ """
+ Expand a set of formulas by using transformations (parallelized):
+ - Fliparoos
+ - Sign switching
+ - Associativity and Commutativity
+
+ :param formulas: The set of formulas to expand.
+ :param norm: The norm to use while reducing.
+ :param num_workers: The amount of workers to use.
+ :return: The expanded set of formulas.
+ """
+
+ @lru_cache(maxsize=1000)
+ def cached(formula):
+ return norm(formula)
+
+ def map_multiple(pool, formulas, fn):
+ results: List[List[Formula]] = []
+ for f in formulas:
+ pool.submit_task(f, fn, f)
+ results.append([])
+ for f, future in pool.as_completed():
+ results[formulas.index(f)] = reduce_by_eq(future.result())
+ return reduce_by_eq(sum(results, []))
+
+ def map_single(pool, formulas, fn):
+ return reduce_by_eq(pool.map(fn, formulas))
+
+ extended = sorted(formulas, key=attrgetter("name"))
+ extended = reduce_with_similarity(extended, cached)
+
+ with TaskExecutor(max_workers=num_workers) as pool:
+ fliparood = map_multiple(pool, extended, recursive_fliparoo)
+ extended.extend(fliparood)
+ extended = reduce_with_similarity(extended, cached)
+
+ switched = map_multiple(pool, extended, switch_signs)
+ extended.extend(switched)
+ extended = reduce_with_similarity(extended, cached)
+
+ add_reduced = map_single(pool, extended, reduce_all_adds)
+ extended.extend(add_reduced)
+ extended = reduce_with_similarity(extended, cached)
+
+ mul_expanded = map_single(pool, extended, expand_all_muls)
+ extended.extend(mul_expanded)
+ extended = reduce_with_similarity(extended, cached)
+
+ np2_expanded = map_single(pool, extended, expand_all_nopower2_muls)
+ extended.extend(np2_expanded)
+ extended = reduce_with_similarity(extended, cached)
+
+ return set(reduce_by_eq(extended))
diff --git a/pyecsca/ec/formula/fliparoo.py b/pyecsca/ec/formula/fliparoo.py
index fd744b5..4d14f80 100644
--- a/pyecsca/ec/formula/fliparoo.py
+++ b/pyecsca/ec/formula/fliparoo.py
@@ -1,6 +1,6 @@
"""Provides a way to Fliparoo formulas."""
from ast import parse
-from typing import Iterator, List, Type, Optional, Set
+from typing import Iterator, List, Type, Optional
from public import public
from pyecsca.ec.op import OpType
@@ -254,7 +254,7 @@ def generate_fliparood_formulas(
fliparoos = find_fliparoos(graph)
for i, fliparoo in enumerate(fliparoos):
for j, flip_graph in enumerate(generate_fliparood_graphs(fliparoo)):
- yield flip_graph.to_formula(f"fliparoo[{i},{j}]")
+ yield flip_graph.to_formula(f"fliparoo[{i},{j}]") # noqa
def generate_fliparood_graphs(fliparoo: Fliparoo) -> Iterator[FormulaGraph]:
diff --git a/pyecsca/ec/formula/switch_sign.py b/pyecsca/ec/formula/switch_sign.py
index 6fede9d..95c18be 100644
--- a/pyecsca/ec/formula/switch_sign.py
+++ b/pyecsca/ec/formula/switch_sign.py
@@ -6,9 +6,9 @@ from itertools import chain, combinations
from pyecsca.ec.op import OpType, CodeOp
from pyecsca.ec.formula.base import Formula
-from pyecsca.ec.formula.graph import FormulaGraph, ConstantNode, CodeOpNode, CodeFormula, Node
+from pyecsca.ec.formula.graph import FormulaGraph, ConstantNode, CodeOpNode, CodeFormula
from pyecsca.ec.point import Point
-from pyecsca.ec.mod import Mod, mod
+from pyecsca.ec.mod import mod
@public
@@ -21,11 +21,17 @@ def generate_switched_formulas(formula: Formula, rename=True) -> Iterator[CodeFo
continue
+def switch_signs(formula: Formula, rename=True) -> List[CodeFormula]:
+ return list(generate_switched_formulas(formula, rename))
+
+
def subnode_lists(graph: FormulaGraph) -> List[List[CodeOpNode]]:
return powerlist(filter(lambda x: x not in graph.roots and x.is_sub, graph.nodes))
-def switch_sign(graph: FormulaGraph, node_combination: List[CodeOpNode]) -> FormulaGraph:
+def switch_sign(
+ graph: FormulaGraph, node_combination: List[CodeOpNode]
+) -> FormulaGraph:
nodes_i = [graph.node_index(node) for node in node_combination]
graph = graph.deepcopy()
node_combination = [graph.nodes[node_i] for node_i in nodes_i] # type: ignore
diff --git a/test/ec/test_formula.py b/test/ec/test_formula.py
index e5fce6e..c0350cb 100644
--- a/test/ec/test_formula.py
+++ b/test/ec/test_formula.py
@@ -6,7 +6,7 @@ import pytest
from sympy import FF, symbols
from importlib_resources import files, as_file
import pyecsca.ec
-from pyecsca.ec.formula.expand import expand_formula_set
+from pyecsca.ec.formula.expand import expand_formula_set, expand_formula_set_parallel
from pyecsca.ec.formula.fliparoo import generate_fliparood_formulas
from pyecsca.ec.formula.graph import rename_ivs
from pyecsca.ec.formula.metrics import (
@@ -543,3 +543,10 @@ def test_formula_correctness(library_formula_params):
def test_formula_expand(add):
res = expand_formula_set({add})
assert len(res) > 1
+
+
+def test_formula_expand_parallel(add):
+ res = expand_formula_set_parallel({add})
+ assert len(res) > 1
+ other = expand_formula_set({add})
+ assert res == other