aboutsummaryrefslogtreecommitdiffhomepage
path: root/src/sec_certs/model/evaluation.py
blob: 6f14fda5e1b8e59f19e509639c141385487b6a7a (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
from __future__ import annotations

import json
import logging
from pathlib import Path

import numpy as np

from sec_certs.dataset.cpe import CPEDataset
from sec_certs.sample.cc import CCCertificate
from sec_certs.sample.fips import FIPSCertificate
from sec_certs.serialization.json import CustomJSONEncoder
from sec_certs.utils import helpers

logger = logging.getLogger(__name__)


def get_validation_dgsts(filepath: str | Path) -> set[str]:
    with Path(filepath).open("r") as handle:
        return set(json.load(handle))


def compute_precision(y: np.ndarray, y_pred: np.ndarray, **kwargs) -> float:
    prec = []
    for true, pred in zip(y, y_pred):
        set_pred = set(pred) if pred else set()
        set_true = set(true) if true else set()
        if not set_pred and not set_true:
            pass
        else:
            prec.append(len(set_true) / len(set_true.union(set_pred)))
    return np.mean(prec)  # type: ignore


def evaluate(
    x_valid: list[CCCertificate | FIPSCertificate],
    y_valid: list[set[str] | None],
    outpath: Path | str | None,
    cpe_dset: CPEDataset,
) -> None:
    y_pred = [x.heuristics.cpe_matches for x in x_valid]
    precision = compute_precision(np.array(y_valid), np.array(y_pred))

    correctly_classified = []
    badly_classified = []
    n_cpes_lost = 0
    n_certs_has_lost_cpes = 0
    n_certs_lost = 0

    for cert, predicted_cpes, verified_cpes in zip(x_valid, y_pred, y_valid):
        if not verified_cpes:
            verified_cpes = set()
        verified_cpes_dict = {x: cpe_dset[x].title if cpe_dset[x].title else x for x in verified_cpes}

        if not predicted_cpes:
            predicted_cpes = set()
        predicted_cpes_dict = {x: cpe_dset[x].title if cpe_dset[x].title else x for x in predicted_cpes}

        cert_name = cert.name if isinstance(cert, CCCertificate) else cert.web_data.module_name
        vendor = cert.manufacturer if isinstance(cert, CCCertificate) else cert.web_data.vendor

        should_be_removed = {x: cpe_dset[x].title if cpe_dset[x].title else x for x in predicted_cpes - verified_cpes}
        should_be_added = {x: cpe_dset[x].title if cpe_dset[x].title else x for x in verified_cpes - predicted_cpes}

        record = {
            "certificate_name": cert_name,
            "vendor": vendor,
            "heuristic version": helpers.compute_heuristics_version(cert_name) if cert_name else None,
            "predicted_cpes": predicted_cpes_dict,
            "manually_assigned_cpes": verified_cpes_dict,
            "should_be_removed": should_be_removed,
            "should_be_added": should_be_added,
        }

        if not predicted_cpes and verified_cpes:
            n_certs_lost += 1

        if predicted_cpes.issubset(verified_cpes):
            correctly_classified.append(record)
        else:
            badly_classified.append(record)

        if should_be_added:
            n_certs_has_lost_cpes += 1
            n_cpes_lost += len(should_be_added)

    results = {
        "Precision": precision,
        "n_cpes_lost": n_cpes_lost,
        "n_certs_has_lost_cpes": n_certs_has_lost_cpes,
        "n_certs_lost": n_certs_lost,
        "correctly_classified": correctly_classified,
        "badly_classified": badly_classified,
    }
    logger.info(
        f"Precision: {precision}; the classifier now misses {n_cpes_lost} CPE matches from {n_certs_has_lost_cpes} certificates ({n_certs_lost} certificates no longer have a match). In total, {sum([len(x) for x in y_pred if x])} CPEs were matched in {len(y_pred)} certs."
    )

    if outpath:
        with Path(outpath).open("w") as handle:
            json.dump(results, handle, indent=4, cls=CustomJSONEncoder, sort_keys=True)