aboutsummaryrefslogtreecommitdiffhomepage
path: root/src/sec_certs/utils/label_studio_utils.py
blob: 82d663bee757f34b07bee1fc8fbb5ee330481055 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
from __future__ import annotations

import json
import logging
from pathlib import Path

from tqdm import tqdm

from sec_certs.configuration import config
from sec_certs.dataset.auxiliary_dataset_handling import CPEDatasetHandler
from sec_certs.dataset.dataset import Dataset
from sec_certs.sample.cpe import CPE

# Logger used by all functions of this module, named after the module itself.
logger = logging.getLogger(name=__name__)


def to_label_studio_json(dataset: Dataset, output_path: str | Path) -> None:
    """
    Export certificates that have CPE match candidates into a JSON file for Label Studio.

    Each certificate with non-empty ``heuristics.cpe_matches`` becomes one dict. The dict holds
    the certificate's ``label_studio_title`` under ``"text"``. It holds the titles of its candidate
    CPEs under ``option_1`` .. ``option_<config.cpe_n_max_matches>``. Unused slots are filled with
    the ``"No good match"`` placeholder, and surplus candidates are cut off, so every dict has
    exactly ``config.cpe_n_max_matches`` option keys.

    :param dataset: Dataset whose certificates are exported; its auxiliary datasets are loaded here.
    :param output_path: Where to write the JSON list of dicts.
    """
    dataset.load_auxiliary_datasets()
    cpe_dset = dataset.aux_handlers[CPEDatasetHandler].dset
    n_options = config.cpe_n_max_matches

    tasks = []
    for cert in dataset:
        if not cert.heuristics.cpe_matches:
            continue

        task = {"text": cert.label_studio_title}
        candidates = [cpe_dset[match].title for match in cert.heuristics.cpe_matches]
        candidates += ["No good match"] * (n_options - len(candidates))

        # Option keys are 1-based and run up to n_options *inclusive*. An exclusive upper bound
        # would silently drop the last candidate, which is a real match when a certificate has
        # exactly n_options of them.
        task.update({f"option_{i}": title for i, title in enumerate(candidates[:n_options], start=1)})
        tasks.append(task)

    with Path(output_path).open("w", encoding="utf-8") as handle:
        json.dump(tasks, handle, indent=4)


def load_label_studio_labels(dataset: Dataset, input_path: str | Path) -> set[str]:
    """
    Read a Label Studio export and store the verified CPE matches on the certificates of ``dataset``.

    For every annotation, the candidate titles stored under ``option_*`` keys count as correct
    matches, except the ``"No good match"`` placeholder. Options whose keys were selected in
    ``verified_cpe_match`` are treated as incorrect and dropped. The remaining titles are
    translated to CPEs through the CPE dataset's title-to-CPE mapping. The resulting CPE URIs are
    written to ``heuristics.verified_cpe_matches`` of every certificate whose name matches the
    annotation text. The value is ``None`` when no CPE remains.

    :param dataset: Dataset whose certificates receive the labels; its auxiliary datasets are loaded here.
    :param input_path: Path to the JSON exported from Label Studio (a list of annotation dicts).
    :return: Digests of all certificates that were matched by name to some annotation.
    """
    with Path(input_path).open("r", encoding="utf-8") as handle:
        data = json.load(handle)

    dataset.load_auxiliary_datasets()
    cpe_dset = dataset.aux_handlers[CPEDatasetHandler].dset
    title_to_cpes_dict = cpe_dset.get_title_to_cpes_dict()
    labeled_cert_digests: set[str] = set()

    logger.info("Translating label studio matches into their CPE representations and assigning to certificates.")
    for annotation in tqdm(data, desc="Translating label studio matches"):
        cpe_candidate_keys = {
            key for key in annotation if key.startswith("option_") and annotation[key] != "No good match"
        }

        # `verified_cpe_match` is absent, a single choice string, or a dict carrying a list of choices.
        if "verified_cpe_match" not in annotation:
            incorrect_keys: set[str] = set()
        elif isinstance(annotation["verified_cpe_match"], str):
            incorrect_keys = {annotation["verified_cpe_match"]}
        else:
            incorrect_keys = set(annotation["verified_cpe_match"]["choices"])

        # Choice values may carry a leading "$" (presumably Label Studio's `$option_N` variable
        # syntax — confirm against the labeling config); strip it to compare with the option keys.
        incorrect_keys = {key.lstrip("$") for key in incorrect_keys}
        predicted_titles = {annotation[key] for key in cpe_candidate_keys - incorrect_keys}

        cpes: set[CPE] = set()
        for title in predicted_titles:
            if title not in title_to_cpes_dict:
                logger.error(f"{title} not in dataset")
                continue
            matched_cpes = title_to_cpes_dict[title]
            if matched_cpes:
                # Always merge into our own set. Binding `cpes` to the dict's value and later
                # calling update() would mutate title_to_cpes_dict and leak CPEs into later annotations.
                cpes.update(matched_cpes)

        # Distinguish between FIPS and CC: FIPS titles carry the name on a "Module name: " line,
        # other titles are used verbatim as the certificate name.
        text = annotation["text"]
        if "\nModule name: " in text:
            cert_name = text.split("\nModule name: ")[1].split("\n")[0]
        else:
            cert_name = text

        certs = dataset.get_certs_by_name(cert_name)
        labeled_cert_digests.update({cert.dgst for cert in certs})

        for cert in certs:
            # A fresh set per certificate, so certificates sharing a name do not share one mutable set.
            cert.heuristics.verified_cpe_matches = {cpe.uri for cpe in cpes if cpe is not None} if cpes else None

    return labeled_cert_digests