aboutsummaryrefslogtreecommitdiffhomepage
path: root/pyecsca/sca/trace/trace.py
blob: dbab02d4bb456cf3378a7ee18871bf9589596659 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
"""Provides the Trace class."""
import weakref
from typing import Any, Mapping, Sequence, Optional
from copy import copy, deepcopy

from numpy import ndarray
import numpy as np
from numpy.typing import DTypeLike
from public import public


@public
class Trace:
    """
    Trace, which has some samples and metadata.

    The trace set a trace belongs to (if any) is held only through a weak reference,
    so a trace never keeps its trace set alive.
    """

    meta: Mapping[str, Any]
    samples: ndarray

    def __init__(
        self, samples: ndarray, meta: Optional[Mapping[str, Any]] = None, trace_set: Any = None
    ):
        """
        Construct a new trace.

        :param samples: The sample array of the trace.
        :param meta: Metadata associated with the trace. A fresh empty dict is used when omitted.
        :param trace_set: A trace set the trace is contained in (stored as a weak reference).
        """
        if meta is None:
            meta = {}
        self.meta = meta
        self.samples = samples
        # Goes through the property setter below, which wraps the set in a weakref.
        self.trace_set = trace_set

    def __len__(self):
        """Length of the trace, in samples."""
        return len(self.samples)

    def __getitem__(self, index):
        """Get the sample at `index`."""
        return self.samples[index]

    def __setitem__(self, key, value):
        """Set the sample at `key`."""
        self.samples[key] = value

    def __iter__(self):
        """Iterate over the samples."""
        yield from self.samples

    @property
    def trace_set(self) -> Any:
        """
        Return the trace set this trace is contained in, if any.

        Returns ``None`` when no trace set was given, or when the weakly referenced
        trace set has already been garbage collected.
        """
        if self._trace_set is None:
            return None
        return self._trace_set()

    @trace_set.setter
    def trace_set(self, trace_set: Any):
        """Set the trace set of this trace (held via :func:`weakref.ref`)."""
        if trace_set is None:
            self._trace_set = None
        else:
            self._trace_set = weakref.ref(trace_set)

    def __getstate__(self):
        """
        Return the pickling state.

        The weak reference to the trace set is replaced by ``None``, as weak
        references cannot be pickled.
        """
        state = self.__dict__.copy()
        state["_trace_set"] = None
        return state

    def __setstate__(self, state):
        """Restore the trace from pickled state; the restored trace has no trace set."""
        self._trace_set = None
        self.__dict__.update(state)

    def __eq__(self, other):
        """
        Compare two traces.

        Traces are equal when their samples are element-wise equal (including shape)
        and their metadata compare equal. The trace set is not compared.
        """
        if not isinstance(other, Trace):
            return False
        return np.array_equal(self.samples, other.samples) and self.meta == other.meta

    def __hash__(self):
        """
        Hash the trace consistently with :meth:`__eq__`.

        ``__eq__`` compares metadata as a mapping, which ignores insertion order, and
        metadata values may be unhashable (e.g. lists). Therefore only the
        order-independent set of metadata keys is hashed, never the item sequence.
        This will have collisions, but those are sorted out by the equality check.
        """
        return hash((str(self.samples), frozenset(self.meta)))

    def with_samples(self, samples: ndarray) -> "Trace":
        """
        Construct a copy of this trace, with the same metadata, but samples replaced by `samples`.

        The metadata is deep-copied; the new trace has no trace set.

        :param samples: The samples of the new trace.
        :return: The new trace.
        """
        return Trace(samples, deepcopy(self.meta))

    def astype(self, dtype: DTypeLike) -> "Trace":
        """
        Construct a copy of this trace, with the same samples retyped using `dtype`.

        :param dtype: The numpy dtype.
        :return: The new trace (with deep-copied metadata and no trace set).
        """
        # ``astype`` already returns a new array; ``asarray`` only strips ndarray
        # subclasses (e.g. memmap) without copying the data a second time.
        return self.with_samples(np.asarray(self.samples.astype(dtype)))

    def __copy__(self):
        """
        Shallow-copy the trace: samples and metadata are shallow-copied.

        The copy refers to the *same* trace set. Copying the trace set instead would
        create an object referenced only by the copy's weakref, which would then be
        garbage collected immediately, leaving the copy without a trace set.
        """
        return Trace(copy(self.samples), copy(self.meta), self.trace_set)

    def __deepcopy__(self, memodict):
        """Deep-copy samples and metadata; the copy is not attached to any trace set."""
        return Trace(
            deepcopy(self.samples, memo=memodict) if isinstance(self.samples, np.ndarray) else np.array(self.samples),
            deepcopy(self.meta, memo=memodict)
        )

    def __repr__(self):
        return f"Trace(samples={self.samples!r}, trace_set={self.trace_set!r})"


@public
class CombinedTrace(Trace):
    """
    Trace that was combined from other traces, :paramref:`~.CombinedTrace.parents`.

    The parent traces are kept in a :class:`weakref.WeakSet`, so a combined trace
    does not keep its parents alive.
    """

    def __init__(
        self,
        samples: ndarray,
        meta: Optional[Mapping[str, Any]] = None,
        trace_set: Any = None,
        parents: Optional[Sequence[Trace]] = None,
    ):
        """
        Construct a new combined trace.

        :param samples: The sample array of the trace.
        :param meta: Metadata associated with the trace.
        :param trace_set: A trace set the trace is contained in.
        :param parents: The traces this trace was combined from. When ``None``,
                        the ``parents`` attribute is ``None`` as well.
        """
        super().__init__(samples, meta, trace_set=trace_set)
        # Weak references only: the combined trace must not pin its parents in memory.
        self.parents = weakref.WeakSet(parents) if parents is not None else None

    def __repr__(self):
        body = f"samples={self.samples!r}, trace_set={self.trace_set!r}, parents={self.parents}"
        return f"CombinedTrace({body})"