from __future__ import annotations

import logging
from collections import Counter
from collections.abc import Callable
from dataclasses import dataclass, field
from enum import Enum
from typing import TypeVar

from sec_certs.sample.certificate import Certificate, References
from sec_certs.serialization.json import ComplexSerializableType

CertSubType = TypeVar("CertSubType", bound=Certificate)
IDLookupFunc = Callable[[CertSubType], str]
ReferenceLookupFunc = Callable[[CertSubType], References]


class ReferenceType(Enum):
    """Selects which `*_referenced_by` set of a certificate's References is used."""

    DIRECT = "direct"
    INDIRECT = "indirect"


@dataclass
class TransitiveCVEs(ComplexSerializableType):
    """CVEs collected for one certificate from the certificates listed in its `*_referenced_by` sets."""

    direct_transitive_cves: set[str] | None = field(default=None)
    indirect_transitive_cves: set[str] | None = field(default=None)


Certificates = dict[str, CertSubType]
Vulnerabilities = dict[str, dict[str, set[str] | None]]


class TransitiveVulnerabilityFinder:
    """
    The class assigns vulnerabilities to each certificate instance caused by references among certificate instances.
    Adheres to sklearn BaseEstimator interface.

    Certificate IDs (as returned by `id_func`) that occur more than once in the fitted dataset are treated as
    ambiguous: such certificates get no entry of their own, and references to such IDs contribute no CVEs.
    """

    def __init__(self, id_func: IDLookupFunc):
        self.vulnerabilities: Vulnerabilities = {}
        self.certificates: Certificates = {}
        self._cert_id_counter: Counter = Counter()
        # Certificate ID -> certificate, only for IDs that occur exactly once in `self.certificates`.
        self._unique_certs_by_id: dict[str, CertSubType] = {}
        self._fitted = False
        self._id_func: IDLookupFunc = id_func

    def _clear_state(self) -> None:
        """Reset all fitted state so that a new `fit()` starts from scratch."""
        self.vulnerabilities = {}
        self.certificates = {}
        self._cert_id_counter = Counter()
        self._unique_certs_by_id = {}
        # The state was just wiped; if fit() fails before completing, predictions must not be served.
        self._fitted = False

    def _fill_dataset_cert_ids_counter(self) -> None:
        """Count certificate IDs in the dataset and index the certificates whose ID is unique."""
        ids_and_certs = [(self._id_func(cert), cert) for cert in self.certificates.values()]
        self._cert_id_counter = Counter(cert_id for cert_id, _ in ids_and_certs)
        # Built once here so that resolving a referenced ID is a dict lookup instead of a full dataset scan.
        self._unique_certs_by_id = {
            cert_id: cert for cert_id, cert in ids_and_certs if self._cert_id_counter[cert_id] == 1
        }

    def _get_cert_transitive_cves(
        self, cert: CertSubType, reference_type: ReferenceType, ref_func: ReferenceLookupFunc
    ) -> set[str] | None:
        """
        Collect `heuristics.related_cves` of all certificates listed in the selected `*_referenced_by` set of `cert`.

        :param CertSubType cert: Certificate whose references are inspected
        :param ReferenceType reference_type: DIRECT uses `directly_referenced_by`, INDIRECT `indirectly_referenced_by`
        :param ReferenceLookupFunc ref_func: Function returning the References object of a certificate
        :return set[str] | None: Union of the collected CVEs, or None when there are no references or no CVEs
        """
        cert_references = ref_func(cert)
        references = (
            cert_references.directly_referenced_by
            if reference_type == ReferenceType.DIRECT
            else cert_references.indirectly_referenced_by
        )
        if not references:
            return None

        vulnerabilities: set[str] = set()
        for referencing_id in references:
            # IDs that are missing from the dataset or shared by several certificates cannot be resolved.
            referencing_cert = self._unique_certs_by_id.get(referencing_id)
            if referencing_cert is None:
                continue
            cves = referencing_cert.heuristics.related_cves
            if cves:
                vulnerabilities.update(cves)

        return vulnerabilities if vulnerabilities else None

    def fit(self, certificates: Certificates, ref_func: ReferenceLookupFunc) -> Vulnerabilities:
        """
        Method assigns each certificate vulnerabilities caused by references among certificates

        Certificates without an ID are skipped. Certificates whose ID is shared with another certificate are
        skipped as well, and their number is logged as a warning.

        :param Certificates certificates: Dictionary of certificates with digests
        :param ReferenceLookupFunc ref_func: Function returning the References object of a certificate
        :return Vulnerabilities: Dictionary mapping certificate digest to
            {"direct": set of CVEs or None, "indirect": set of CVEs or None}
        """
        self._clear_state()
        self.certificates = certificates
        self._fill_dataset_cert_ids_counter()

        thrown_away_cert_counter = 0
        for cert in self.certificates.values():
            cert_id = self._id_func(cert)
            if not cert_id:
                continue
            if self._cert_id_counter[cert_id] > 1:
                thrown_away_cert_counter += 1
                continue

            self.vulnerabilities[cert.dgst] = {
                ReferenceType.DIRECT.value: self._get_cert_transitive_cves(cert, ReferenceType.DIRECT, ref_func),
                ReferenceType.INDIRECT.value: self._get_cert_transitive_cves(cert, ReferenceType.INDIRECT, ref_func),
            }

        if thrown_away_cert_counter > 0:
            logging.warning("There were total of %s certificates skipped due to duplicity", thrown_away_cert_counter)

        self._fitted = True
        return self.vulnerabilities

    def predict_single_cert(self, dgst: str) -> TransitiveCVEs:
        """
        Method returns vulnerabilities for certificate digest

        :param str dgst: Digest of certificate
        :raises ValueError: If the finder has not been fitted yet
        :return TransitiveCVE: TransitiveCVE object of certificate (both fields None for unknown digests)
        """
        if not self._fitted:
            raise ValueError("Finder not yet fitted")

        if not self.vulnerabilities.get(dgst):
            return TransitiveCVEs(direct_transitive_cves=None, indirect_transitive_cves=None)

        return TransitiveCVEs(
            self.vulnerabilities[dgst][ReferenceType.DIRECT.value],
            self.vulnerabilities[dgst][ReferenceType.INDIRECT.value],
        )

    def predict(self, dgst_list: list[str]) -> dict[str, TransitiveCVEs]:
        """
        Method returns vulnerabilities for a list of certificate digests

        :param List[str] dgst_list: list of certificate digests
        :raises ValueError: If the finder has not been fitted yet
        :return Dict[str, TransitiveCVE]: Dictionary of TransitiveCVE objects for specified certificate digests
        """
        if not self._fitted:
            raise ValueError("Finder not yet fitted")

        return {dgst: self.predict_single_cert(dgst) for dgst in dgst_list}