aboutsummaryrefslogtreecommitdiff
path: root/pyecsca/sca/attack/leakage_model.py
blob: f9adcff4c1def18c00cd99ff2668c1bd96c7cc3a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import abc
import sys
from typing import Literal, ClassVar

from numpy.random import default_rng
from public import public

from ...sca.trace import Trace

# ``int.bit_count`` only exists on Python >= 3.10; on older interpreters fall
# back to counting the "1" characters of the binary string representation.
if sys.version_info < (3, 10):
    def hw(i: int) -> int:
        """
        Return the Hamming weight (number of set bits) of ``i``.

        For negative ``i`` the bits of ``abs(i)`` are counted, since
        ``bin(-x)`` is just ``"-" + bin(x)``.
        """
        return bin(i).count("1")
else:
    def hw(i: int) -> int:
        """
        Return the Hamming weight (number of set bits) of ``i``.

        For negative ``i`` the bits of ``abs(i)`` are counted, as documented
        for ``int.bit_count``.
        """
        return i.bit_count()


@public
class Noise:
    """
    Common base class of the noise models in this module.

    The subclasses here are callables that take the value to perturb as their
    first positional argument and return the (possibly) perturbed result.
    """


@public
class ZeroNoise(Noise):
    """
    Noise model that adds no noise at all.
    """

    def __call__(self, *args, **kwargs):
        """Return the first positional argument unchanged."""
        value = args[0]
        return value


@public
class NormalNoice(Noise):
    """
    Additive Gaussian noise.

    Each call adds draws from a normal distribution with mean ``mean`` and
    standard deviation ``sdev`` to its first positional argument.

    https://www.youtube.com/watch?v=SAfq55aiqPc
    """

    def __init__(self, mean: float, sdev: float, seed=None):
        """
        :param mean: Mean of the normal distribution.
        :param sdev: Standard deviation of the normal distribution; must be >= 0.
        :param seed: Optional seed passed to ``numpy.random.default_rng``.
                     ``None`` (the default) gives fresh, unpredictable noise;
                     a fixed value makes the noise reproducible.
        :raises ValueError: If ``sdev`` is negative.
        """
        # Fail early and clearly, instead of on the first call inside numpy.
        if sdev < 0:
            raise ValueError("sdev must be >= 0.")
        self.rng = default_rng(seed)
        self.mean = mean
        self.sdev = sdev

    def __call__(self, *args, **kwargs):
        """
        Add normally distributed noise to the first positional argument.

        If it is a :class:`Trace`, a new ``Trace`` is returned.  It is built from
        the original samples plus ``len(arg.samples)`` independent draws, one
        added to each sample.  Otherwise a single draw is added to the value
        and the sum is returned.
        """
        arg = args[0]
        if isinstance(arg, Trace):
            # NOTE(review): only the samples are passed to the new Trace; any
            # other attributes of ``arg`` are not carried over - confirm intended.
            return Trace(arg.samples + self.rng.normal(self.mean, self.sdev, len(arg.samples)))
        return arg + self.rng.normal(self.mean, self.sdev)


@public
class LeakageModel(abc.ABC):
    """
    Abstract base for leakage models.

    A leakage model is a callable that maps one or more values to the amount
    they leak.  Concrete models set :attr:`num_args` and implement ``__call__``.
    """

    # How many positional arguments ``__call__`` reads (e.g. 1 for a Hamming
    # weight, 2 for a Hamming distance).
    num_args: ClassVar[int]

    @abc.abstractmethod
    def __call__(self, *args, **kwargs) -> int:
        """Compute the leakage of the given positional argument(s)."""
        raise NotImplementedError


@public
class Identity(LeakageModel):
    """
    Leakage model that leaks the value itself, converted to ``int``.
    """

    num_args = 1

    def __call__(self, *args, **kwargs) -> int:
        value = args[0]
        return int(value)


@public
class Bit(LeakageModel):
    """
    Leakage model that leaks a single bit of the value.

    :param which: Zero-based position of the leaked bit (bit 0 is the least
                  significant bit).
    :raises ValueError: If ``which`` is negative.
    """

    num_args = 1

    def __init__(self, which: int):
        if which < 0:
            raise ValueError("which must be >= 0.")
        self.which = which
        # Selects only the bit at position ``which``.
        self.mask = 1 << which

    def __call__(self, *args, **kwargs) -> Literal[0, 1]:
        """Return bit ``which`` of ``int(args[0])`` as 0 or 1."""
        value = int(args[0])
        selected = value & self.mask
        return selected >> self.which  # type: ignore


@public
class Slice(LeakageModel):
    """
    Leakage model that leaks a contiguous range of bits of the value.

    The bits at positions ``begin`` (inclusive) to ``end`` (exclusive) are
    extracted and shifted down so that bit ``begin`` becomes bit 0.

    :param begin: Position of the lowest leaked bit; must be >= 0.
    :param end: Position one past the highest leaked bit; must be >= ``begin``.
    :raises ValueError: If ``begin > end`` or ``begin`` is negative.
    """

    num_args = 1

    def __init__(self, begin: int, end: int):
        if begin > end:
            raise ValueError("begin must be <= than end.")
        # A negative begin would otherwise fail later with an opaque
        # "negative shift count" error (or only at call time if begin == end).
        if begin < 0:
            raise ValueError("begin must be >= 0.")
        self.begin = begin
        self.end = end
        # ``end - begin`` one-bits moved up to ``begin``: the same mask as
        # OR-ing ``1 << i`` for every ``i`` in ``range(begin, end)``.
        self.mask = ((1 << (end - begin)) - 1) << begin

    def __call__(self, *args, **kwargs) -> int:
        """Return bits ``[begin, end)`` of ``int(args[0])``, shifted down to bit 0."""
        return (int(args[0]) & self.mask) >> self.begin


@public
class HammingWeight(LeakageModel):
    """
    Leakage model that leaks the Hamming weight (number of set bits) of the value.
    """

    num_args = 1

    def __call__(self, *args, **kwargs) -> int:
        value = int(args[0])
        return hw(value)


@public
class HammingDistance(LeakageModel):
    """
    Leakage model that leaks the Hamming distance between two values.

    This is the number of bit positions in which the two values differ.
    """

    num_args = 2

    def __call__(self, *args, **kwargs) -> int:
        first, second = int(args[0]), int(args[1])
        differing = first ^ second
        return hw(differing)


@public
class BitLength(LeakageModel):
    """
    Leakage model that leaks the bit length of the value (``int.bit_length``).
    """

    num_args = 1

    def __call__(self, *args, **kwargs) -> int:
        value = int(args[0])
        return value.bit_length()