aboutsummaryrefslogtreecommitdiffhomepage
path: root/src/sec_certs/dataset/fips_iut.py
blob: 45986c11f7623886ab048054695a28eb41cedbe9 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
from __future__ import annotations

from collections.abc import Iterator, Mapping
from dataclasses import dataclass
from pathlib import Path
from tempfile import NamedTemporaryFile

import requests

from sec_certs.configuration import config
from sec_certs.dataset.dataset import logger
from sec_certs.dataset.json_path_dataset import JSONPathDataset
from sec_certs.sample.fips_iut import IUTSnapshot
from sec_certs.serialization.json import ComplexSerializableType
from sec_certs.utils.tqdm import tqdm


@dataclass
class IUTDataset(JSONPathDataset, ComplexSerializableType):
    """
    Collection of ``IUTSnapshot`` objects.

    The dataset behaves like a read-only sequence of its snapshots (iteration,
    integer indexing and ``len()``) and can be built from a directory of dumps,
    from a dictionary, or from a JSON file downloaded from the web.
    """

    snapshots: list[IUTSnapshot]

    def __init__(self, snapshots: list[IUTSnapshot], json_path: str | Path | None = None):
        """
        :param snapshots: snapshots held by the dataset.
        :param json_path: optional path forwarded to ``JSONPathDataset.__init__``.
        """
        super().__init__(json_path)
        self.snapshots = snapshots

    def __iter__(self) -> Iterator[IUTSnapshot]:
        """Iterate over the snapshots in their stored order."""
        yield from self.snapshots

    def __getitem__(self, item: int) -> IUTSnapshot:
        """Return the snapshot stored at position ``item``."""
        return self.snapshots[item]

    def __len__(self) -> int:
        """Return the number of snapshots in the dataset."""
        return len(self.snapshots)

    @classmethod
    def from_dumps(cls, dump_path: str | Path) -> IUTDataset:
        """
        Build a dataset from all entries found directly in a directory.

        Every entry is passed to ``IUTSnapshot.from_dump`` in sorted path order.
        Entries that raise are logged and skipped, so one bad dump does not
        abort the whole load.

        :param dump_path: directory containing the dumps.
        :return: dataset with the successfully loaded snapshots (no ``json_path`` set).
        """
        directory = Path(dump_path)
        fnames = list(directory.glob("*"))
        snapshots = []
        # A distinct loop name avoids rebinding the ``dump_path`` parameter.
        for snapshot_path in tqdm(sorted(fnames), total=len(fnames)):
            try:
                snapshots.append(IUTSnapshot.from_dump(snapshot_path))
            except Exception as e:
                # Deliberate best-effort: report the failure and continue.
                logger.error(e)
        return cls(snapshots)

    def to_dict(self) -> dict[str, list[IUTSnapshot]]:
        """Return the serializable form: a shallow copy of the snapshot list under ``"snapshots"``."""
        return {"snapshots": list(self.snapshots)}

    @classmethod
    def from_dict(cls, dct: Mapping) -> IUTDataset:
        """
        Build a dataset from a mapping as produced by :meth:`to_dict`.

        :param dct: mapping with a ``"snapshots"`` key.
        """
        return cls(dct["snapshots"])

    @classmethod
    def from_web(cls) -> IUTDataset:
        """
        Get the IUTDataset from sec-certs.org

        Downloads the resource at ``config.fips_iut_dataset``, stores it in a
        temporary ``.json`` file and loads it through ``from_json``.

        :raises ValueError: if the HTTP response status is not OK.
        """
        # NOTE(review): no timeout is passed, so this call can block indefinitely
        # on an unresponsive server - consider adding one.
        iut_resp = requests.get(config.fips_iut_dataset)
        if iut_resp.status_code != requests.codes.ok:
            raise ValueError(f"Getting IUT dataset failed: {iut_resp.status_code}")
        with NamedTemporaryFile(suffix=".json") as tmpfile:
            tmpfile.write(iut_resp.content)
            # The file is reopened by name below; without a flush, part (or all)
            # of the payload may still sit in this handle's write buffer and the
            # reader would see truncated JSON.
            tmpfile.flush()
            return cls.from_json(tmpfile.name)