1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
|
from __future__ import annotations
import logging
from collections import Counter
from collections.abc import Callable
from dataclasses import dataclass, field
from enum import Enum
from typing import TypeVar
from sec_certs.sample.certificate import Certificate, References
from sec_certs.serialization.json import ComplexSerializableType
# Generic certificate type handled by the finder; restricted to subclasses of ``Certificate``.
CertSubType = TypeVar("CertSubType", bound=Certificate)
# Maps a certificate to its ID; the finder compares these IDs against the IDs listed in certificate references.
IDLookupFunc = Callable[[CertSubType], str]
# Maps a certificate to its ``References`` object; the finder reads ``directly_referenced_by`` and
# ``indirectly_referenced_by`` from it.
ReferenceLookupFunc = Callable[[CertSubType], References]
class ReferenceType(Enum):
    """
    Kind of certificate reference to follow when collecting transitive CVEs.

    The member values double as the keys of the per-certificate dictionary produced by
    ``TransitiveVulnerabilityFinder.fit``.
    """

    DIRECT = "direct"
    INDIRECT = "indirect"
@dataclass
class TransitiveCVEs(ComplexSerializableType):
    """
    CVEs attributed to a single certificate through references, split by reference type.

    ``TransitiveVulnerabilityFinder`` stores ``None`` in a field when no CVEs were found for that reference type.
    """

    # CVEs gathered through direct references (ReferenceType.DIRECT).
    direct_transitive_cves: set[str] | None = field(default=None)
    # CVEs gathered through indirect references (ReferenceType.INDIRECT).
    indirect_transitive_cves: set[str] | None = field(default=None)
# Certificates passed to ``TransitiveVulnerabilityFinder.fit``; presumably keyed by certificate digest
# (the fit docstring describes it as a "dictionary of certificates with digests").
Certificates = dict[str, CertSubType]
# Result of ``TransitiveVulnerabilityFinder.fit``: certificate digest -> {ReferenceType value -> set of CVE IDs or None}.
# NOTE: unlike annotations, this alias is evaluated at import time even with ``from __future__ import annotations``,
# so the ``set[str] | None`` union requires Python 3.10+.
Vulnerabilities = dict[str, dict[str, set[str] | None]]
class TransitiveVulnerabilityFinder:
    """
    The class assigns vulnerabilities to each certificate instance caused by references among certificate instances.

    For every certificate with a non-empty ID that is unique in the dataset, the CVEs
    (``heuristics.related_cves``) of the certificates listed in its ``directly_referenced_by`` and
    ``indirectly_referenced_by`` references are collected separately. IDs shared by several certificates
    are ambiguous: such certificates get no entry, and references to such IDs contribute no CVEs.
    Adheres to sklearn BaseEstimator interface.
    """

    def __init__(self, id_func: IDLookupFunc):
        """
        :param IDLookupFunc id_func: Function returning the ID of a certificate, matched against referenced IDs
        """
        self.vulnerabilities: Vulnerabilities = {}
        self.certificates: Certificates = {}
        self._cert_id_counter: Counter = Counter()
        # Certificates whose ID occurs exactly once in ``self.certificates``, keyed by that ID.
        # Turns reference resolution into a dict lookup instead of a scan over all certificates.
        self._unique_certs_by_id: dict[str, CertSubType] = {}
        self._fitted = False
        self._id_func: IDLookupFunc = id_func

    def _clear_state(self) -> None:
        """Drop everything computed by a previous fit, including the fitted flag."""
        self.vulnerabilities = {}
        self.certificates = {}
        self._cert_id_counter = Counter()
        self._unique_certs_by_id = {}
        # Reset so that a fit raising half-way does not leave cleared/partial results marked as fitted.
        self._fitted = False

    def _fill_dataset_cert_ids_counter(self) -> None:
        """Count how many certificates in the dataset share each ID."""
        self._cert_id_counter = Counter(self._id_func(cert) for cert in self.certificates.values())

    def _build_unique_cert_index(self) -> None:
        """Index the certificates whose ID is unique in the dataset (needs the ID counter filled first)."""
        self._unique_certs_by_id = {}
        for cert in self.certificates.values():
            cert_id = self._id_func(cert)
            if self._cert_id_counter[cert_id] == 1:
                self._unique_certs_by_id[cert_id] = cert

    def _get_cert_transitive_cves(
        self, cert: CertSubType, reference_type: ReferenceType, ref_func: ReferenceLookupFunc
    ) -> set[str] | None:
        """
        Collect the CVEs of the certificates listed in ``cert``'s ``*_referenced_by`` references.

        :param CertSubType cert: Certificate whose references are examined
        :param ReferenceType reference_type: Selects ``directly_referenced_by`` or ``indirectly_referenced_by``
        :param ReferenceLookupFunc ref_func: Function returning the References object of a certificate
        :return set[str] | None: Union of the CVEs found, or None when there are none
        """
        references = ref_func(cert)
        referencing_ids = (
            references.directly_referenced_by
            if reference_type == ReferenceType.DIRECT
            else references.indirectly_referenced_by
        )
        if not referencing_ids:
            return None

        vulnerabilities: set[str] = set()
        for cert_id in referencing_ids:
            # IDs absent from the dataset or shared by several certificates are not indexed, hence skipped.
            referencing_cert = self._unique_certs_by_id.get(cert_id)
            if referencing_cert is None:
                continue
            cves = referencing_cert.heuristics.related_cves
            if cves:
                vulnerabilities.update(cves)
        return vulnerabilities or None

    def fit(self, certificates: Certificates, ref_func: ReferenceLookupFunc) -> Vulnerabilities:
        """
        Method assigns each certificate vulnerabilities caused by references among certificates
        :param Certificates certificates: Dictionary of certificates with digests
        :param ReferenceLookupFunc ref_func: Function returning the References object of a certificate
        :return Vulnerabilities: Dictionary of vulnerabilities of certificate instances, keyed by certificate digest
        """
        self._clear_state()
        self.certificates = certificates
        self._fill_dataset_cert_ids_counter()
        self._build_unique_cert_index()

        thrown_away_cert_counter = 0
        for cert in self.certificates.values():
            cert_id = self._id_func(cert)
            if not cert_id:
                continue
            # Ambiguous IDs cannot be resolved reliably, so such certificates are skipped.
            if self._cert_id_counter[cert_id] > 1:
                thrown_away_cert_counter += 1
                continue
            self.vulnerabilities[cert.dgst] = {
                ReferenceType.DIRECT.value: self._get_cert_transitive_cves(cert, ReferenceType.DIRECT, ref_func),
                ReferenceType.INDIRECT.value: self._get_cert_transitive_cves(cert, ReferenceType.INDIRECT, ref_func),
            }

        if thrown_away_cert_counter > 0:
            logging.warning("There were total of %s certificates skipped due to duplicity", thrown_away_cert_counter)

        self._fitted = True
        return self.vulnerabilities

    def predict_single_cert(self, dgst: str) -> TransitiveCVEs:
        """
        Method returns vulnerabilities for certificate digest
        :param str dgst: Digest of certificate
        :return TransitiveCVEs: TransitiveCVEs object of certificate (both fields None if the digest is unknown)
        :raises ValueError: If the finder has not been fitted
        """
        if not self._fitted:
            raise ValueError("Finder not yet fitted")

        cert_vulnerabilities = self.vulnerabilities.get(dgst)
        if not cert_vulnerabilities:
            return TransitiveCVEs(direct_transitive_cves=None, indirect_transitive_cves=None)

        return TransitiveCVEs(
            cert_vulnerabilities[ReferenceType.DIRECT.value],
            cert_vulnerabilities[ReferenceType.INDIRECT.value],
        )

    def predict(self, dgst_list: list[str]) -> dict[str, TransitiveCVEs]:
        """
        Method returns vulnerabilities for a list of certificate digests
        :param List[str] dgst_list: list of certificate digests
        :return Dict[str, TransitiveCVEs]: Dictionary of TransitiveCVEs objects for specified certificate digests
        :raises ValueError: If the finder has not been fitted
        """
        if not self._fitted:
            raise ValueError("Finder not yet fitted")

        return {dgst: self.predict_single_cert(dgst) for dgst in dgst_list}
|