aboutsummaryrefslogtreecommitdiffhomepage
path: root/src/sec_certs/model/transitive_vulnerability_finder.py
blob: 93c1496eb9b988272427006bd9b6c792b061ff45 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
from __future__ import annotations

import logging
from collections import Counter
from collections.abc import Callable
from dataclasses import dataclass, field
from enum import Enum
from typing import TypeVar

from sec_certs.sample.certificate import Certificate, References
from sec_certs.serialization.json import ComplexSerializableType

# Any Certificate subclass the finder operates on.
CertSubType = TypeVar("CertSubType", bound=Certificate)
# Callable returning the References object of a certificate; passed to ``TransitiveVulnerabilityFinder.fit``.
ReferenceLookupFunc = Callable[[CertSubType], References]
# Callable returning the identifier that other certificates use to reference this certificate.
IDLookupFunc = Callable[[CertSubType], str]


class ReferenceType(Enum):
    """Kind of reference between certificates; member values are used as keys of the per-certificate result dict."""

    # The certificate is referenced straight away by another certificate.
    DIRECT = "direct"
    # The certificate is reached only through a chain of references.
    INDIRECT = "indirect"


@dataclass
class TransitiveCVEs(ComplexSerializableType):
    """
    CVEs a certificate inherits from certificates that reference it.

    Each attribute is a set of CVE strings, or ``None`` when no such CVEs were found.
    """

    # CVEs gathered from certificates that reference this one directly.
    direct_transitive_cves: set[str] | None = None
    # CVEs gathered from certificates that reference this one indirectly.
    indirect_transitive_cves: set[str] | None = None


# Per-certificate results keyed by certificate digest: {ReferenceType value -> set of CVE strings, or None}.
Vulnerabilities = dict[str, dict[str, set[str] | None]]
# Input collection of certificates; presumably keyed by certificate digest — TODO confirm against callers.
Certificates = dict[str, CertSubType]


class TransitiveVulnerabilityFinder:
    """
    The class assigns vulnerabilities to each certificate instance caused by references among certificate instances.

    For every certificate, the CVEs (``heuristics.related_cves``) of the certificates listed in its
    ``directly_referenced_by`` / ``indirectly_referenced_by`` references are collected. Certificate IDs
    (as returned by ``id_func``) shared by more than one certificate are ambiguous. Such certificates
    are skipped, and such IDs are never resolved to a CVE source.

    Adheres to sklearn BaseEstimator interface.
    """

    def __init__(self, id_func: IDLookupFunc):
        """
        :param IDLookupFunc id_func: Function returning the identifier under which a certificate is referenced
        """
        self.vulnerabilities: Vulnerabilities = {}
        self.certificates: Certificates = {}
        self._cert_id_counter: Counter = Counter()
        # Index of certificates whose ID occurs exactly once in the dataset (ID -> certificate).
        self._unique_id_to_cert: dict = {}
        self._fitted = False
        self._id_func: IDLookupFunc = id_func

    def _clear_state(self) -> None:
        """Reset all state derived from a previous ``fit``."""
        self.vulnerabilities = {}
        self.certificates = {}
        self._cert_id_counter = Counter()
        self._unique_id_to_cert = {}
        # If the upcoming fit fails, the finder must not look fitted while its state is emptied.
        self._fitted = False

    def _fill_dataset_cert_ids_counter(self) -> None:
        """Count certificate IDs in the dataset and index the certificates whose ID is unique."""
        # Evaluate id_func once per certificate and reuse the result for both structures.
        ids_and_certs = [(self._id_func(cert), cert) for cert in self.certificates.values()]
        self._cert_id_counter = Counter(cert_id for cert_id, _ in ids_and_certs)
        self._unique_id_to_cert = {
            cert_id: cert for cert_id, cert in ids_and_certs if self._cert_id_counter[cert_id] == 1
        }

    def _get_cert_transitive_cves(
        self, cert: CertSubType, reference_type: ReferenceType, ref_func: ReferenceLookupFunc
    ) -> set[str] | None:
        """
        Collect the CVEs of the certificates that reference ``cert``.

        :param CertSubType cert: Certificate whose referencing certificates are inspected
        :param ReferenceType reference_type: DIRECT uses ``directly_referenced_by``, otherwise ``indirectly_referenced_by``
        :param ReferenceLookupFunc ref_func: Function returning the References object of a certificate
        :return set[str] | None: Union of the referencing certificates' CVEs, or None if there are none
        """
        cert_references = ref_func(cert)
        referencing_ids = (
            cert_references.directly_referenced_by
            if reference_type == ReferenceType.DIRECT
            else cert_references.indirectly_referenced_by
        )

        if not referencing_ids:
            return None

        vulnerabilities: set[str] = set()

        for cert_id in referencing_ids:
            # IDs missing from the dataset or shared by several certificates cannot be resolved unambiguously.
            referencing_cert = self._unique_id_to_cert.get(cert_id)
            if referencing_cert is None:
                continue

            cves = referencing_cert.heuristics.related_cves
            if cves:
                vulnerabilities.update(cves)

        return vulnerabilities or None

    def fit(self, certificates: Certificates, ref_func: ReferenceLookupFunc) -> Vulnerabilities:
        """
        Method assigns each certificate vulnerabilities caused by references among certificates

        Certificates with an empty ID are skipped silently. Certificates whose ID is shared with other
        certificates are skipped, and their number is logged as a warning.

        :param Certificates certificates: Dictionary of certificates with digests
        :param ReferenceLookupFunc ref_func: Function returning the References object of a certificate
        :return Vulnerabilities: Dictionary of vulnerabilities of certificate instances, keyed by certificate digest
        """
        self._clear_state()
        self.certificates = certificates
        self._fill_dataset_cert_ids_counter()

        thrown_away_cert_counter = 0

        for cert in self.certificates.values():
            cert_id = self._id_func(cert)

            if not cert_id:
                continue
            if self._cert_id_counter[cert_id] > 1:
                thrown_away_cert_counter += 1
                continue

            self.vulnerabilities[cert.dgst] = {
                ReferenceType.DIRECT.value: self._get_cert_transitive_cves(cert, ReferenceType.DIRECT, ref_func),
                ReferenceType.INDIRECT.value: self._get_cert_transitive_cves(cert, ReferenceType.INDIRECT, ref_func),
            }

        if thrown_away_cert_counter > 0:
            # Module-named logger: library code should not log through (and implicitly configure) the root logger.
            logging.getLogger(__name__).warning(
                "There were total of %s certificates skipped due to duplicity", thrown_away_cert_counter
            )

        self._fitted = True

        return self.vulnerabilities

    def predict_single_cert(self, dgst: str) -> TransitiveCVEs:
        """
        Method returns vulnerabilities for certificate digest

        :param str dgst: Digest of certificate
        :raises ValueError: If the finder has not been fitted yet
        :return TransitiveCVEs: TransitiveCVEs object of certificate (both fields None for unknown digests)
        """
        if not self._fitted:
            raise ValueError("Finder not yet fitted")

        cert_vulnerabilities = self.vulnerabilities.get(dgst)
        if not cert_vulnerabilities:
            return TransitiveCVEs(direct_transitive_cves=None, indirect_transitive_cves=None)

        return TransitiveCVEs(
            cert_vulnerabilities[ReferenceType.DIRECT.value],
            cert_vulnerabilities[ReferenceType.INDIRECT.value],
        )

    def predict(self, dgst_list: list[str]) -> dict[str, TransitiveCVEs]:
        """
        Method returns vulnerabilities for a list of certificate digests

        :param List[str] dgst_list: list of certificate digests
        :raises ValueError: If the finder has not been fitted yet (also for an empty list)
        :return Dict[str, TransitiveCVEs]: Dictionary of TransitiveCVEs objects for specified certificate digests
        """
        if not self._fitted:
            raise ValueError("Finder not yet fitted")

        return {dgst: self.predict_single_cert(dgst) for dgst in dgst_list}