aboutsummaryrefslogtreecommitdiff
path: root/pyecsca/ec/mod/symbolic.py
blob: c3d006be8d359a6a3d60eae93644e67b688462a2 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
from functools import wraps

from public import public
from sympy import Expr

from pyecsca.ec.mod.base import Mod


def _check(func):
    """Decorate a binary method so its operand is a compatible element.

    The wrapped method is always called with an ``other`` whose type is
    exactly ``self.__class__``:

    * an operand of any other type is wrapped as ``self.__class__(other, self.n)``;
    * an operand of the same type must share the modulus ``n``, otherwise
      :class:`ValueError` is raised.
    """

    @wraps(func)
    def wrapper(self, other):
        cls = self.__class__
        if type(other) is cls:
            # Same element type: only the moduli need to agree.
            if self.n != other.n:
                raise ValueError
        else:
            # Foreign operand: lift it into this class under our modulus.
            other = cls(other, self.n)
        return func(self, other)

    return wrapper


@public
class SymbolicMod(Mod):
    """A symbolic element x of ℤₙ (implemented using sympy).

    The value ``x`` is stored as given and arithmetic simply builds a new
    expression from the operands; no reduction modulo ``n`` is applied by
    this class.
    """

    x: Expr
    n: int
    __slots__ = ("x", "n")

    def __init__(self, x: Expr, n: int):
        """Store the expression ``x`` and the modulus ``n`` unchanged."""
        self.x = x
        self.n = n

    @_check
    def __add__(self, other) -> "SymbolicMod":
        """Return the element whose expression is ``self.x + other.x``."""
        return self.__class__((self.x + other.x), self.n)

    @_check
    def __radd__(self, other) -> "SymbolicMod":
        """Reflected addition; delegates to ``self + other``."""
        return self + other

    @_check
    def __sub__(self, other) -> "SymbolicMod":
        """Return the element whose expression is ``self.x - other.x``."""
        return self.__class__((self.x - other.x), self.n)

    @_check
    def __rsub__(self, other) -> "SymbolicMod":
        """Reflected subtraction, computed as ``-self + other``."""
        return -self + other

    def __neg__(self) -> "SymbolicMod":
        """Return the element whose expression is ``-self.x``."""
        return self.__class__(-self.x, self.n)

    def bit_length(self) -> int:
        """Not supported for symbolic elements."""
        raise NotImplementedError

    def inverse(self) -> "SymbolicMod":
        """Return the element whose expression is ``self.x ** (-1)``."""
        return self.__class__(self.x ** (-1), self.n)

    def is_residue(self) -> bool:
        """Not supported for symbolic elements."""
        raise NotImplementedError

    def sqrt(self) -> "SymbolicMod":
        """Not supported for symbolic elements."""
        raise NotImplementedError

    def is_cubic_residue(self) -> bool:
        """Not supported for symbolic elements."""
        raise NotImplementedError

    def cube_root(self) -> "SymbolicMod":
        """Not supported for symbolic elements."""
        raise NotImplementedError

    def __invert__(self) -> "SymbolicMod":
        """``~a`` is shorthand for ``a.inverse()``."""
        return self.inverse()

    @_check
    def __mul__(self, other) -> "SymbolicMod":
        """Return the element whose expression is ``self.x * other.x``."""
        return self.__class__(self.x * other.x, self.n)

    @_check
    def __rmul__(self, other) -> "SymbolicMod":
        """Reflected multiplication; delegates to ``self * other``."""
        return self * other

    @_check
    def __truediv__(self, other) -> "SymbolicMod":
        """Divide by multiplying with the inverse of ``other``."""
        return self * ~other

    @_check
    def __rtruediv__(self, other) -> "SymbolicMod":
        """Reflected division: ``other`` times the inverse of ``self``."""
        return ~self * other

    @_check
    def __floordiv__(self, other) -> "SymbolicMod":
        """Same computation as ``/``: multiply with the inverse of ``other``."""
        return self * ~other

    @_check
    def __rfloordiv__(self, other) -> "SymbolicMod":
        """Same computation as reflected ``/``."""
        return ~self * other

    def __bytes__(self):
        """Encode ``int(self.x)`` big-endian, sized to the byte length of ``n``.

        Requires ``x`` to be convertible with ``int()``.
        """
        return int(self.x).to_bytes((self.n.bit_length() + 7) // 8, byteorder="big")

    def __int__(self):
        """Return ``int(self.x)``; requires ``x`` to be convertible with ``int()``."""
        return int(self.x)

    def __eq__(self, other):
        """Compare against a plain ``int`` or another symbolic element.

        A plain ``int`` is reduced modulo ``n`` and compared with ``x``.
        A :class:`SymbolicMod` (or subclass instance) is equal when both
        ``x`` and ``n`` are equal. Every other operand compares unequal.
        """
        # Exact-type test on purpose: bool (an int subclass) is not treated
        # as an integer operand here.
        if type(other) is int:
            return self.x == other % self.n
        # isinstance rather than an exact type test: otherwise a subclass
        # instance would not even equal itself, while __hash__ below does not
        # depend on the concrete class.
        if not isinstance(other, SymbolicMod):
            return False
        return self.x == other.x and self.n == other.n

    def __ne__(self, other):
        """Negation of :meth:`__eq__`."""
        return not self == other

    def __repr__(self):
        """Return the string form of the expression only (``n`` is omitted)."""
        return str(self.x)

    def __hash__(self):
        """Hash over a fixed tag, the expression and the modulus."""
        return hash(("SymbolicMod", self.x, self.n))

    def __pow__(self, n, _=None) -> "SymbolicMod":
        """Return the element whose expression is ``pow(self.x, n)``.

        The optional third ``pow()`` argument is accepted but ignored.
        """
        return self.__class__(pow(self.x, n), self.n)


# Late import of the class registry, kept below the class definition.
# The bare `# noqa` suppresses lint warnings for this line (presumably the
# import-not-at-top-of-file check).
# NOTE(review): the placement looks deliberate — confirm before moving it up.
from pyecsca.ec.mod.base import _mod_classes  # noqa

# Register SymbolicMod under the "symbolic" key of the registry.
_mod_classes["symbolic"] = SymbolicMod