aboutsummaryrefslogtreecommitdiff
path: root/pyecsca/ec/formula/unroll.py
blob: 3c63e00034ebac4f60a5a0b7422edd59cee74655 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
"""Provides functions for unrolling formula intermediate values symvolically."""

from typing import List, Tuple

from astunparse import unparse
from public import public
from sympy import Expr, symbols, Poly

from pyecsca.misc.cache import sympify
from pyecsca.ec.formula.base import Formula


@public
def unroll_formula_expr(formula: Formula) -> List[Tuple[str, Expr]]:
    """
    Unroll a given formula symbolically to obtain symbolic expressions for its intermediate values.

    Curve parameters and formula inputs start out as plain symbols. A curve
    parameter fixed by a coordinate model assumption (``param = expr``) is
    replaced by the assumed expression, and a parameter introduced by a formula
    assumption (``param == expr``) is expressed in terms of the parameters known
    so far. The formula operations are then evaluated in order over these
    symbolic values.

    :param formula: Formula to unroll.
    :return: List of symbolic intermediate values, with associated variable names.
    :raises ValueError: If an assumption does not split into exactly two sides
        around its ``" = "`` (coordinate model) or ``" == "`` (formula) separator.
    """
    params = {
        var: symbols(var)
        for var in formula.coordinate_model.curve_model.parameter_names
    }
    # Inputs are named "<variable><index>", with indices starting at 1.
    inputs = {
        f"{var}{i}": symbols(f"{var}{i}")
        for var in formula.coordinate_model.variables
        for i in range(1, formula.num_inputs + 1)
    }
    for coord_assumption in formula.coordinate_model.assumptions:
        assumption_string = unparse(coord_assumption).strip()
        lhs, rhs = assumption_string.split(" = ")
        if lhs in params:
            # The coordinate model fixes this curve parameter to an expression.
            expr = sympify(rhs, evaluate=False)
            params[lhs] = expr
    for assumption_string in formula.assumptions_str:
        lhs, rhs = assumption_string.split(" == ")
        if lhs in formula.parameters:
            # Handle a symbolic assignment to a new parameter.
            expr = sympify(rhs, evaluate=False)
            for curve_param, value in params.items():
                # xreplace looks its keys up among the expression nodes, so the
                # key has to be the Symbol itself; a plain string name never
                # matches and the substitution would silently do nothing.
                expr = expr.xreplace({symbols(curve_param): value})
            params[lhs] = expr

    # Namespace handed to every operation; extended with each result so later
    # operations can refer to earlier intermediate values.
    locls = {**params, **inputs}
    values: List[Tuple[str, Expr]] = []
    for op in formula.code:
        result: Expr = op(**locls)  # type: ignore
        locls[op.result] = result
        values.append((op.result, result))
    return values


@public
def unroll_formula(formula: Formula) -> List[Tuple[str, Poly]]:
    """
    Unroll a formula symbolically and turn its intermediate values into polynomials.

    Each polynomial is built over the free symbols of the corresponding
    expression, ordered by their string representation. Intermediate values
    without free symbols (i.e. constants) are left out of the result.

    :param formula: Formula to unroll.
    :return: List of polynomial intermediate values, with associated variable names.
    """
    unrolled = []
    for var_name, expr in unroll_formula_expr(formula):
        free = expr.free_symbols
        if not free:
            # TODO: We cannot create a Poly here, because the result does not have free symbols (i.e. it is a constant)
            continue
        # Sort by name so the generator order is deterministic.
        generators = sorted(free, key=str)
        unrolled.append((var_name, Poly(expr, *generators)))
    return unrolled