aboutsummaryrefslogtreecommitdiff
path: root/pyecsca/sca/trace/match.py
blob: e1181e98114a3dd6bf425ba62d1c4138630c56b5 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
"""
This module provides functions for matching a pattern within a trace to it.
"""
import numpy as np
from scipy.signal import find_peaks
from public import public
from typing import List

from .process import normalize
from .edit import trim
from .trace import Trace


@public
def match_pattern(trace: Trace, pattern: Trace, threshold: float = 0.8) -> List[int]:
    """
    Match a `pattern` to a `trace`. Returns indices where the pattern matches, e.g. those where correlation
    of the two traces has peaks larger than `threshold`. Uses the :py:func:`scipy.signal.find_peaks` function.

    :param trace: The trace to match into.
    :param pattern: The pattern to match.
    :param threshold: The threshold passed to :py:func:`scipy.signal.find_peaks` as a `prominence` value.
    :return: Indices where the pattern matches.
    """
    normalized = normalize(trace)
    pattern_samples = normalize(pattern).samples
    correlation = np.correlate(normalized.samples, pattern_samples, "same")
    correlation = (correlation - np.mean(correlation)) / (np.max(correlation))
    peaks, props = find_peaks(correlation, prominence=(threshold, None))
    pairs = sorted(zip(peaks, props["prominences"]), key=lambda it: it[1], reverse=True)
    half = len(pattern_samples) // 2
    filtered_peaks: List[int] = []
    for peak, _ in pairs:
        if not filtered_peaks:
            filtered_peaks.append(peak - half)
        else:
            for other_peak in filtered_peaks:
                if abs((peak - half) - other_peak) <= len(pattern_samples):
                    break
            else:
                filtered_peaks.append(peak - half)
    return filtered_peaks


@public
def match_part(
    trace: Trace, offset: int, length: int, threshold: float = 0.8
) -> List[int]:
    """
    Match a part of a `trace` starting at `offset` of `length` to the `trace`. Returns indices where the pattern matches
    , e.g. those where correlation of the two traces has peaks larger than `threshold`. Uses the
    :py:func:`scipy.signal.find_peaks` function.

    :param trace: The trace to match into.
    :param offset: The start of the pattern in the trace to match.
    :param length: The length of the pattern in the trace to match.
    :param threshold: The threshold passed to :py:func:`scipy.signal.find_peaks` as a `prominence` value.
    :return: Indices where the part of the trace matches matches.
    """
    return match_pattern(trace, trim(trace, offset, offset + length), threshold)