aboutsummaryrefslogtreecommitdiffhomepage
path: root/src/sec_certs/dataset/fips_mip.py
blob: 41ab464641e326d0c2c017d2021f1742b51e7707 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
from __future__ import annotations

from collections.abc import Iterator, Mapping
from dataclasses import dataclass
from datetime import date
from operator import attrgetter
from pathlib import Path
from tempfile import NamedTemporaryFile

import requests

from sec_certs.configuration import config
from sec_certs.dataset.dataset import logger
from sec_certs.dataset.json_path_dataset import JSONPathDataset
from sec_certs.sample.fips_mip import MIPFlow, MIPSnapshot, MIPStatus
from sec_certs.serialization.json import ComplexSerializableType
from sec_certs.utils.tqdm import tqdm


@dataclass
class MIPDataset(JSONPathDataset, ComplexSerializableType):
    """
    A dataset of MIP snapshots (:class:`MIPSnapshot`).

    The dataset behaves like a read-only sequence of its snapshots (iteration,
    indexing, ``len``) and can be reduced into per-module :class:`MIPFlow` objects
    via :meth:`compute_flows`.
    """

    snapshots: list[MIPSnapshot]

    def __init__(self, snapshots: list[MIPSnapshot], json_path: str | Path | None = None):
        """
        :param snapshots: The snapshots that make up the dataset.
        :param json_path: Forwarded unchanged to the ``JSONPathDataset`` initializer.
        """
        super().__init__(json_path)
        self.snapshots = snapshots

    def __iter__(self) -> Iterator[MIPSnapshot]:
        """Iterate over the snapshots in their stored order."""
        yield from self.snapshots

    def __getitem__(self, item: int) -> MIPSnapshot:
        """Return the snapshot stored at position ``item``."""
        return self.snapshots[item]

    def __len__(self) -> int:
        """Return the number of snapshots in the dataset."""
        return len(self.snapshots)

    @classmethod
    def from_dumps(cls, dump_path: str | Path) -> MIPDataset:
        """
        Build the dataset from a directory of snapshot dumps.

        Every entry directly inside the directory is handed to
        ``MIPSnapshot.from_dump`` in sorted path order. An entry that fails to load
        is logged and skipped, so the result may hold fewer snapshots than there are
        directory entries.

        :param dump_path: Directory holding the dumps.
        :return: The dataset of the successfully loaded snapshots (no JSON path set).
        """
        # Use a dedicated loop variable instead of re-binding the ``dump_path`` argument.
        fnames = sorted(Path(dump_path).glob("*"))
        snapshots = []
        for fname in tqdm(fnames, total=len(fnames)):
            try:
                snapshots.append(MIPSnapshot.from_dump(fname))
            except Exception as e:
                # Deliberately best-effort: one unreadable dump must not abort the load.
                logger.error(e)
        return cls(snapshots)

    def to_dict(self) -> dict[str, list[MIPSnapshot]]:
        """Return the serializable form: a copy of the snapshot list under ``"snapshots"``."""
        return {"snapshots": list(self.snapshots)}

    @classmethod
    def from_dict(cls, dct: Mapping) -> MIPDataset:
        """
        Build the dataset from the mapping shape produced by :meth:`to_dict`.

        :param dct: Mapping with a ``"snapshots"`` key.
        :return: The dataset (no JSON path set).
        """
        return cls(dct["snapshots"])

    @classmethod
    def from_web(cls) -> MIPDataset:
        """
        Get the MIPDataset from sec-certs.org

        The JSON is downloaded from the URL in ``config.fips_mip_dataset``, stored in
        a temporary file and loaded through ``from_json``.

        :return: The downloaded dataset.
        :raises ValueError: If the server answers with a non-OK status code.
        """
        mip_resp = requests.get(config.fips_mip_dataset)
        if mip_resp.status_code != requests.codes.ok:
            raise ValueError(f"Getting MIP dataset failed: {mip_resp.status_code}")
        with NamedTemporaryFile(suffix=".json") as tmpfile:
            tmpfile.write(mip_resp.content)
            # The file is re-opened by name below; without flushing, buffered bytes
            # may not have reached the disk yet and the reader would see truncated
            # (or empty) JSON.
            tmpfile.flush()
            return cls.from_json(tmpfile.name)

    def compute_flows(self) -> list[MIPFlow]:
        """
        Compute the MIPFlows, deduplicating the MIPEntries in the snapshots
        and computing their state-changes.

        Snapshots are processed in ascending ``timestamp`` order. Entries are grouped
        by ``(module_name, vendor_name, standard)``; when a snapshot holds several
        distinct entries for the same key, a warning is logged and the smallest one
        (by the entries' own ordering) is used. For each key, a new ``(date, status)``
        pair is appended whenever the status differs from the previously recorded one.
        While the status stays the same, the date of the last pair is moved forward to
        the current snapshot's date.

        :return: The MIPFlows.
        """
        flows: dict[tuple[str, str, str], list[tuple[date, MIPStatus]]] = {}
        for snapshot in sorted(self.snapshots, key=attrgetter("timestamp")):
            snapshot_date = snapshot.timestamp.date()

            # Group this snapshot's entries by key so duplicates can be detected.
            entries: dict[tuple[str, str, str], set] = {}
            for entry in snapshot:
                key = (entry.module_name, entry.vendor_name, entry.standard)
                entries.setdefault(key, set()).add(entry)

            for key, dups in entries.items():
                if len(dups) > 1:
                    logger.warning(f"Duplicate MIPEntry when computing MIPFlow, {key}.")
                # Sorting makes the choice among duplicates deterministic.
                entry = sorted(dups)[0]
                entry_flows = flows.setdefault(key, [])
                if not entry_flows:
                    # First sighting of this key.
                    entry_flows.append((snapshot_date, entry.status))
                    continue

                _, last_status = entry_flows[-1]
                if last_status != entry.status:
                    entry_flows.append((snapshot_date, entry.status))
                else:
                    # NOTE(review): an unchanged status overwrites the recorded date, so
                    # each pair carries the date the status was last seen rather than the
                    # date it first appeared - confirm this is what consumers expect.
                    entry_flows[-1] = (snapshot_date, entry.status)

        return [MIPFlow(*key, value) for key, value in flows.items()]