diff options
| author | J08nY | 2024-08-26 15:38:58 +0200 |
|---|---|---|
| committer | J08nY | 2024-08-26 15:38:58 +0200 |
| commit | 481f567792d05e6a4e09121a7358252b11706ee8 (patch) | |
| tree | 654903eeefc005716e1f84ae41a071d532affb34 | |
| parent | fcb17bf74e4cc6acedbad707a9ddf4e5a59e8a0f (diff) | |
| download | pyecsca-481f567792d05e6a4e09121a7358252b11706ee8.tar.gz pyecsca-481f567792d05e6a4e09121a7358252b11706ee8.tar.zst pyecsca-481f567792d05e6a4e09121a7358252b11706ee8.zip | |
Add parallel version of formula expansion.
| -rw-r--r-- | pyecsca/ec/formula/expand.py | 86 | ||||
| -rw-r--r-- | pyecsca/ec/formula/fliparoo.py | 4 | ||||
| -rw-r--r-- | pyecsca/ec/formula/switch_sign.py | 12 | ||||
| -rw-r--r-- | test/ec/test_formula.py | 9 |
4 files changed, 97 insertions, 14 deletions
diff --git a/pyecsca/ec/formula/expand.py b/pyecsca/ec/formula/expand.py index eff227d..7e9b5ee 100644 --- a/pyecsca/ec/formula/expand.py +++ b/pyecsca/ec/formula/expand.py @@ -1,6 +1,6 @@ """Provides a formula expansion function.""" -from typing import Set, Callable, Any, List, Iterable +from typing import Set, Callable, Any, List, Iterable, Optional from public import public from operator import attrgetter from functools import lru_cache @@ -14,12 +14,14 @@ from pyecsca.ec.formula.partitions import ( expand_all_muls, expand_all_nopower2_muls, ) -from pyecsca.ec.formula.switch_sign import generate_switched_formulas +from pyecsca.ec.formula.switch_sign import generate_switched_formulas, switch_signs +from pyecsca.misc.utils import TaskExecutor def reduce_with_similarity( formulas: List[Formula], norm: Callable[[Formula], Any] ) -> List[Formula]: + formulas = reduce_by_eq(formulas) reduced = list(filter(lambda x: isinstance(x, EFDFormula), formulas)) similarities = list(map(norm, reduced)) for formula in formulas: @@ -51,10 +53,11 @@ def expand_formula_set( - Sign switching - Associativity and Commutativity - :param formulas: - :param norm: - :return: + :param formulas: The set of formulas to expand. + :param norm: The norm to use while reducing. + :return: The expanded set of formulas. """ + @lru_cache(maxsize=1000) def cached(formula): return norm(formula) @@ -62,11 +65,15 @@ def expand_formula_set( extended = sorted(formulas, key=attrgetter("name")) extended = reduce_with_similarity(extended, cached) - fliparood: List[Formula] = reduce_by_eq(sum(map(lambda f: reduce_by_eq(recursive_fliparoo(f)), extended), [])) + fliparood: List[Formula] = reduce_by_eq( + sum(map(lambda f: reduce_by_eq(recursive_fliparoo(f)), extended), []) + ) extended.extend(fliparood) extended = reduce_with_similarity(extended, cached) - switch_signs: List[Formula] = reduce_by_eq(sum(map(lambda f: reduce_by_eq(generate_switched_formulas(f)), extended), [])) + switch_signs: List[Formula] = reduce_by_eq( + sum(map(lambda f: reduce_by_eq(generate_switched_formulas(f)), extended), []) + ) extended.extend(switch_signs) extended = reduce_with_similarity(extended, cached) @@ -78,8 +85,71 @@ def expand_formula_set( extended.extend(mul_expanded) extended = reduce_with_similarity(extended, cached) - np2_expanded: List[Formula] = reduce_by_eq(list(map(expand_all_nopower2_muls, extended))) + np2_expanded: List[Formula] = reduce_by_eq( + list(map(expand_all_nopower2_muls, extended)) + ) extended.extend(np2_expanded) extended = reduce_with_similarity(extended, cached) return set(reduce_by_eq(extended)) + + +@public +def expand_formula_set_parallel( + formulas: Set[Formula], + norm: Callable[[Formula], Any] = ivs_norm, + num_workers: int = 1, +) -> Set[Formula]: + """ + Expand a set of formulas by using transformations (parallelized): + - Fliparoos + - Sign switching + - Associativity and Commutativity + + :param formulas: The set of formulas to expand. + :param norm: The norm to use while reducing. + :param num_workers: The amount of workers to use. + :return: The expanded set of formulas. + """ + + @lru_cache(maxsize=1000) + def cached(formula): + return norm(formula) + + def map_multiple(pool, formulas, fn): + results: List[List[Formula]] = [] + for f in formulas: + pool.submit_task(f, fn, f) + results.append([]) + for f, future in pool.as_completed(): + results[formulas.index(f)] = reduce_by_eq(future.result()) + return reduce_by_eq(sum(results, [])) + + def map_single(pool, formulas, fn): + return reduce_by_eq(pool.map(fn, formulas)) + + extended = sorted(formulas, key=attrgetter("name")) + extended = reduce_with_similarity(extended, cached) + + with TaskExecutor(max_workers=num_workers) as pool: + fliparood = map_multiple(pool, extended, recursive_fliparoo) + extended.extend(fliparood) + extended = reduce_with_similarity(extended, cached) + + switched = map_multiple(pool, extended, switch_signs) + extended.extend(switched) + extended = reduce_with_similarity(extended, cached) + + add_reduced = map_single(pool, extended, reduce_all_adds) + extended.extend(add_reduced) + extended = reduce_with_similarity(extended, cached) + + mul_expanded = map_single(pool, extended, expand_all_muls) + extended.extend(mul_expanded) + extended = reduce_with_similarity(extended, cached) + + np2_expanded = map_single(pool, extended, expand_all_nopower2_muls) + extended.extend(np2_expanded) + extended = reduce_with_similarity(extended, cached) + + return set(reduce_by_eq(extended)) diff --git a/pyecsca/ec/formula/fliparoo.py b/pyecsca/ec/formula/fliparoo.py index fd744b5..4d14f80 100644 --- a/pyecsca/ec/formula/fliparoo.py +++ b/pyecsca/ec/formula/fliparoo.py @@ -1,6 +1,6 @@ """Provides a way to Fliparoo formulas.""" from ast import parse -from typing import Iterator, List, Type, Optional, Set +from typing import Iterator, List, Type, Optional from public import public from pyecsca.ec.op import OpType @@ -254,7 +254,7 @@ def generate_fliparood_formulas( fliparoos = find_fliparoos(graph) for i, fliparoo in enumerate(fliparoos): for j, flip_graph in enumerate(generate_fliparood_graphs(fliparoo)): - yield flip_graph.to_formula(f"fliparoo[{i},{j}]") + yield flip_graph.to_formula(f"fliparoo[{i},{j}]") # noqa def generate_fliparood_graphs(fliparoo: Fliparoo) -> Iterator[FormulaGraph]: diff --git a/pyecsca/ec/formula/switch_sign.py b/pyecsca/ec/formula/switch_sign.py index 6fede9d..95c18be 100644 --- a/pyecsca/ec/formula/switch_sign.py +++ b/pyecsca/ec/formula/switch_sign.py @@ -6,9 +6,9 @@ from itertools import chain, combinations from pyecsca.ec.op import OpType, CodeOp from pyecsca.ec.formula.base import Formula -from pyecsca.ec.formula.graph import FormulaGraph, ConstantNode, CodeOpNode, CodeFormula, Node +from pyecsca.ec.formula.graph import FormulaGraph, ConstantNode, CodeOpNode, CodeFormula from pyecsca.ec.point import Point -from pyecsca.ec.mod import Mod, mod +from pyecsca.ec.mod import mod @public @@ -21,11 +21,17 @@ def generate_switched_formulas(formula: Formula, rename=True) -> Iterator[CodeFo continue +def switch_signs(formula: Formula, rename=True) -> List[CodeFormula]: + return list(generate_switched_formulas(formula, rename)) + + def subnode_lists(graph: FormulaGraph) -> List[List[CodeOpNode]]: return powerlist(filter(lambda x: x not in graph.roots and x.is_sub, graph.nodes)) -def switch_sign(graph: FormulaGraph, node_combination: List[CodeOpNode]) -> FormulaGraph: +def switch_sign( + graph: FormulaGraph, node_combination: List[CodeOpNode] +) -> FormulaGraph: nodes_i = [graph.node_index(node) for node in node_combination] graph = graph.deepcopy() node_combination = [graph.nodes[node_i] for node_i in nodes_i] # type: ignore diff --git a/test/ec/test_formula.py b/test/ec/test_formula.py index e5fce6e..c0350cb 100644 --- a/test/ec/test_formula.py +++ b/test/ec/test_formula.py @@ -6,7 +6,7 @@ import pytest from sympy import FF, symbols from importlib_resources import files, as_file import pyecsca.ec -from pyecsca.ec.formula.expand import expand_formula_set +from pyecsca.ec.formula.expand import expand_formula_set, expand_formula_set_parallel from pyecsca.ec.formula.fliparoo import generate_fliparood_formulas from pyecsca.ec.formula.graph import rename_ivs from pyecsca.ec.formula.metrics import ( @@ -543,3 +543,10 @@ def test_formula_correctness(library_formula_params): def test_formula_expand(add): res = expand_formula_set({add}) assert len(res) > 1 + + +def test_formula_expand_parallel(add): + res = expand_formula_set_parallel({add}) + assert len(res) > 1 + other = expand_formula_set({add}) + assert res == other |
