aboutsummaryrefslogtreecommitdiffhomepage
path: root/src/sec_certs/model/matching.py
blob: 2ba48f647b160609c035b5fa08faecbf8a99dd42 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
from __future__ import annotations

from abc import ABC, abstractmethod
from collections.abc import Sequence
from heapq import heappop, heappush
from typing import Any, Generic, TypeVar

from rapidfuzz import fuzz

from sec_certs.sample.certificate import Certificate

# Type variable for the concrete certificate class a matcher operates on (any subclass of Certificate).
CertSubType = TypeVar("CertSubType", bound=Certificate)


class AbstractMatcher(Generic[CertSubType], ABC):
    """
    Base class for objects that score how well a certificate corresponds to a single ``entry``.

    Subclasses implement :meth:`match`. :meth:`_match_certs` then uses a collection of matchers to
    greedily assign certificates to entries.
    """

    # The object a certificate gets matched to; ``_match_certs`` returns it as the match value.
    entry: Any

    @abstractmethod
    def match(self, cert: CertSubType) -> float:
        """
        Score how well ``cert`` matches this matcher's ``entry`` (higher is better).

        ``_match_certs`` compares the result against its ``threshold``. Presumably the score is on the
        same 0-100 scale as :meth:`_compute_match` -- confirm in subclasses.
        """
        raise NotImplementedError

    def _compute_match(self, one: str, other: str) -> float:
        """
        Return a fuzzy similarity score (0-100) of two strings.

        The result is the maximum of the token-set ratio and two partial ratios. The partial ratios are
        computed with ``score_cutoff=100``, for which rapidfuzz returns 0 for anything below a perfect
        score, so they only contribute when they reach exactly 100.
        """
        return max(
            fuzz.token_set_ratio(one, other),
            fuzz.partial_token_sort_ratio(one, other, score_cutoff=100),
            fuzz.partial_ratio(one, other, score_cutoff=100),
        )

    @staticmethod
    def _match_certs(
        matchers: Sequence[AbstractMatcher], certs: list[CertSubType], threshold: float
    ) -> tuple[dict[str, Any], dict[str, float]]:
        """
        Greedily pair certificates with matcher entries, best score first.

        Every (cert, matcher) pair is scored with ``matcher.match(cert)``. Pairs are then considered from
        the highest score down. A pair is accepted when its score is at least ``threshold`` and neither its
        cert nor its matcher has been used yet, so each cert and each matcher is used at most once.
        Equal scores are resolved in favour of the lower cert index, then the lower matcher index.

        :param matchers: Matchers whose ``entry`` is what a cert gets matched to.
        :param certs: Certificates to match.
        :param threshold: Minimal score (inclusive) a pair needs in order to be accepted.
        :return: Two dicts keyed by ``cert.dgst``: the matched entry, and the score of that match.
        """
        # Min-heap of (negated score, cert index, matcher index). Negation is exact for floats, unlike the
        # former ``100 - score`` key, so the score recovered below is bit-identical to what ``match``
        # returned, and a score equal to ``threshold`` can never be rejected by rounding error.
        heap: list[tuple[float, int, int]] = []
        for i, cert in enumerate(certs):
            for j, matcher in enumerate(matchers):
                score = matcher.match(cert)
                # A pair below the threshold can never be accepted, so there is no point queueing it.
                if score >= threshold:
                    heappush(heap, (-score, i, j))

        results: dict[str, Any] = {}
        final_scores: dict[str, float] = {}
        matched_is: set[int] = set()
        matched_js: set[int] = set()
        # Once every cert (or every matcher) is used, all remaining pairs would be skipped anyway.
        max_matches = min(len(certs), len(matchers))
        while heap and len(matched_is) < max_matches:
            neg_score, i, j = heappop(heap)
            # Do not match already matched entries/certs.
            if i in matched_is or j in matched_js:
                continue
            matched_is.add(i)
            matched_js.add(j)
            dgst = certs[i].dgst
            results[dgst] = matchers[j].entry
            final_scores[dgst] = -neg_score
        return results, final_scores