aboutsummaryrefslogtreecommitdiffhomepage
path: root/src/sec_certs/dataset/dataset.py
blob: 96dfd9f007ca48f4b9b8526d3aa0b2061e86d793 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
from __future__ import annotations

import logging
import shutil
import tarfile
import tempfile
from abc import ABC, abstractmethod
from collections.abc import Iterator
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Any, ClassVar, Generic, TypeVar, cast

import pandas as pd
import requests
from packaging.version import parse as parse_version
from pydantic import AnyHttpUrl

from sec_certs._version import __version__
from sec_certs.dataset.auxiliary_dataset_handling import AuxiliaryDatasetHandler
from sec_certs.sample.certificate import Certificate
from sec_certs.serialization.json import (
    ComplexSerializableType,
    get_class_fullname,
    only_backed,
    serialize,
)
from sec_certs.utils import helpers
from sec_certs.utils.profiling import staged

logger = logging.getLogger(__name__)

CertSubType = TypeVar("CertSubType", bound=Certificate)
DatasetSubType = TypeVar("DatasetSubType", bound="Dataset")


class Dataset(Generic[CertSubType], ComplexSerializableType, ABC):
    """
    Base class for dataset of certificates from CC and FIPS 140 schemes. Lays out public
    functions, the processing pipeline and common operations on the dataset and certs.
    """

    # Default URL of the full dataset archive; `from_web` falls back to it when no `archive_url` is given.
    FULL_ARCHIVE_URL: ClassVar[AnyHttpUrl]
    # Default URL of the dataset snapshot; `from_web` falls back to it when no `snapshot_url` is given.
    SNAPSHOT_URL: ClassVar[AnyHttpUrl]

    @dataclass
    class DatasetInternalState(ComplexSerializableType):
        """
        Flags recording which stages of the processing pipeline were already run on the
        dataset, together with the sec-certs version the state was created with.
        """

        meta_sources_parsed: bool = False
        artifacts_downloaded: bool = False
        pdfs_converted: bool = False
        auxiliary_datasets_processed: bool = False
        certs_analyzed: bool = False
        sec_certs_version: str | None = None

        def __init__(
            self,
            meta_sources_parsed: bool = False,
            artifacts_downloaded: bool = False,
            pdfs_converted: bool = False,
            auxiliary_datasets_processed: bool = False,
            certs_analyzed: bool = False,
            sec_certs_version: str | None = None,
        ):
            # A state without a recorded version is stamped with the running sec-certs version.
            if sec_certs_version is None:
                sec_certs_version = __version__

            self.meta_sources_parsed = meta_sources_parsed
            self.artifacts_downloaded = artifacts_downloaded
            self.pdfs_converted = pdfs_converted
            self.auxiliary_datasets_processed = auxiliary_datasets_processed
            self.certs_analyzed = certs_analyzed
            self.sec_certs_version = sec_certs_version

    def __init__(
        self,
        certs: dict[str, CertSubType] | None = None,
        root_dir: str | Path | None = None,
        name: str | None = None,
        description: str = "",
        state: DatasetInternalState | None = None,
        aux_handlers: dict[type[AuxiliaryDatasetHandler], AuxiliaryDatasetHandler] | None = None,
    ):
        """
        Builds the dataset from the supplied parts, filling in defaults for the omitted ones.

        :param certs: certificates held by the dataset; an empty dict is used when `None`.
        :param root_dir: directory backing the dataset, or `None` for a dataset without one.
        :param name: dataset name; the class name is used when empty.
        :param description: free-form description; the current date and time is used when empty.
        :param state: pipeline state; a fresh `DatasetInternalState` is used when not given.
        :param aux_handlers: auxiliary dataset handlers; an empty dict is used when `None`.
        """
        super().__init__()
        self.certs = {} if certs is None else certs
        self.timestamp = datetime.now()
        self.name = name or type(self).__name__
        self.description = description or datetime.now().strftime("%d/%m/%Y %H:%M:%S")
        self.state = state or self.DatasetInternalState()
        self.root_dir = None if root_dir is None else Path(root_dir)  # type: ignore
        self.aux_handlers = {} if aux_handlers is None else aux_handlers
        # The handlers were not attached yet when the root dir was assigned above, so
        # point the (possibly user-supplied) handlers to the root dir once more.
        self._set_local_paths()

    @property
    def is_backed(self) -> bool:
        """
        Returns whether the dataset is backed by a directory.
        """
        return self.root_dir is not None

    @property
    def root_dir(self) -> Path:
        """
        Directory that will hold the serialized dataset files.
        """
        return self._root_dir  # type: ignore

    @root_dir.setter
    def root_dir(self, new_dir: str | Path | None) -> None:
        """
        This setter will only set the root dir and all internal paths so that they point
        to the new root dir. No data is being moved around.
        """
        if new_dir is None:
            self._root_dir = None
            return

        new_dir = Path(new_dir)
        if new_dir.is_file():
            raise ValueError(f"Root dir of {get_class_fullname(self)} cannot be a file.")

        self._root_dir = new_dir
        self._set_local_paths()

    @property
    @only_backed(throw=False)
    def web_dir(self) -> Path:
        """
        Location of the certification artifacts posted on web: the `web` directory under the root dir.
        """
        return self.root_dir.joinpath("web")

    @property
    @only_backed(throw=False)
    def auxiliary_datasets_dir(self) -> Path:
        """
        Location of the auxiliary datasets: the `auxiliary_datasets` directory under the root dir.
        """
        return self.root_dir.joinpath("auxiliary_datasets")

    @property
    @only_backed(throw=False)
    def certs_dir(self) -> Path:
        """
        Location of the files associated with certificates: the `certs` directory under the root dir.
        """
        return self.root_dir.joinpath("certs")

    @property
    @only_backed(throw=False)
    def json_path(self) -> Path:
        """
        Path of the dataset json file: `<name>.json` directly under the root dir.
        """
        base_dir = self.root_dir
        file_name = self.name + ".json"
        return base_dir / file_name

    def __contains__(self, item: object) -> bool:
        """
        Checks whether a certificate is a member of the dataset; the lookup is done
        by the digest (`dgst`) of the certificate.

        :param item: certificate to look up.
        :raises TypeError: if `item` is not an instance of `Certificate`.
        """
        if not isinstance(item, Certificate):
            # Interpolate the class itself: `type(Certificate)` would name its metaclass
            # and tell the caller nothing about the type that is actually accepted.
            raise TypeError(
                f"You attempted to check if {type(item)} is member of {type(self)}, but only {Certificate} are allowed to be members."
            )
        return item.dgst in self.certs

    def __iter__(self) -> Iterator[CertSubType]:
        yield from self.certs.values()

    def __getitem__(self, item: str) -> CertSubType:
        return self.certs.__getitem__(item.lower())

    def __setitem__(self, key: str, value: CertSubType):
        self.certs.__setitem__(key.lower(), value)

    def __len__(self) -> int:
        return len(self.certs)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, Dataset):
            return NotImplemented
        return self.certs == other.certs

    def __str__(self) -> str:
        return str(type(self).__name__) + ":" + self.name + ", " + str(len(self)) + " certificates"

    @classmethod
    def from_web(  # noqa
        cls: type[DatasetSubType],
        archive_url: AnyHttpUrl | None = None,
        snapshot_url: AnyHttpUrl | None = None,
        progress_bar_desc: str | None = None,
        path: str | Path | None = None,
        auxiliary_datasets: bool = False,
        artifacts: bool = False,
    ) -> DatasetSubType:
        """
        Fetches the fresh dataset snapshot from sec-certs.org.

        Optionally stores it at the given path (a directory) and also downloads auxiliary datasets and artifacts (PDFs).

        .. note::
            Note that including the auxiliary datasets adds several gigabytes and including artifacts adds tens of gigabytes.

        :param archive_url: The URL of the full dataset archive. If `None` provided, defaults to `cls.FULL_ARCHIVE_URL`.
        :param snapshot_url: The URL of the full dataset snapshot. If `None` provided, defaults to `cls.SNAPSHOT_URL`.
        :param progress_bar_desc: Description of the download progress bar. If `None`, will pick reasonable default.
        :param path: Path to a directory where to store the dataset, or `None` if it should not be stored.
        :param auxiliary_datasets: Whether to also download auxiliary datasets (CVE, CPE, CPEMatch datasets).
        :param artifacts: Whether to also download artifacts (i.e. PDFs).
        :raises ValueError: On an invalid combination of arguments, when `path` is not a directory,
            or when the download of the archive/snapshot fails.
        """
        if not archive_url:
            archive_url = cls.FULL_ARCHIVE_URL
        if not snapshot_url:
            snapshot_url = cls.SNAPSHOT_URL
        if not progress_bar_desc:
            progress_bar_desc = f"Downloading: {cls.__name__}"

        if (artifacts or auxiliary_datasets) and path is None:
            raise ValueError("Path needs to be defined if artifacts or auxiliary datasets are to be downloaded.")
        if artifacts and not auxiliary_datasets:
            raise ValueError("Auxiliary datasets need to be downloaded if artifacts are to be downloaded.")
        if path is not None:
            path = Path(path)
            if not path.exists():
                path.mkdir(parents=True)
            if not path.is_dir():
                raise ValueError("Path needs to be a directory.")
        if artifacts:
            # The full archive is large; pick a temporary directory that can hold it when its size is known.
            fsize = helpers.query_file_size(str(archive_url))
            base_tmpdir = tempfile.gettempdir() if fsize is None else helpers.tempdir_for(fsize)
            with tempfile.TemporaryDirectory(dir=base_tmpdir) as tmp_dir:
                dset_path = Path(tmp_dir) / "dataset.tar.gz"
                res = helpers.download_file(
                    str(archive_url),
                    dset_path,
                    show_progress_bar=True,
                    progress_bar_desc=progress_bar_desc,
                )
                if res != requests.codes.ok:
                    raise ValueError(f"Download failed: {res}")
                with tarfile.open(dset_path, "r:gz") as tar:
                    # The archive comes from the network: where the interpreter supports extraction
                    # filters, use the "data" filter so that members cannot escape `path`.
                    if hasattr(tarfile, "data_filter"):
                        tar.extractall(str(path), filter="data")
                    else:
                        tar.extractall(str(path))
                dset = cls.from_json(path / "dataset.json")  # type: ignore
                if auxiliary_datasets:
                    dset.process_auxiliary_datasets(download_fresh=False)
        else:
            with tempfile.TemporaryDirectory() as tmp_dir:
                dset_path = Path(tmp_dir) / "dataset.json"
                res = helpers.download_file(
                    str(snapshot_url),
                    dset_path,
                    show_progress_bar=True,
                    progress_bar_desc=progress_bar_desc,
                )
                # Same check as for the archive above: fail with a clear message instead of
                # trying to parse a missing or partial snapshot file.
                if res != requests.codes.ok:
                    raise ValueError(f"Download failed: {res}")
                dset = cls.from_json(dset_path)
                if path:
                    dset.move_dataset(path)
                else:
                    # Clear the path, as it points to temporary file
                    dset._root_dir = None
            if auxiliary_datasets:
                dset.process_auxiliary_datasets(download_fresh=True)
        return dset

    def to_dict(self) -> dict[str, Any]:
        return {
            "state": self.state,
            "timestamp": self.timestamp,
            "name": self.name,
            "description": self.description,
            "n_certs": len(self),
            "certs": list(self.certs.values()),
        }

    @classmethod
    def from_dict(cls, dct: dict) -> Dataset:
        """
        Builds the dataset from its dictionary representation (see `to_dict`).

        Logs an error when the number of loaded certificates differs from the claimed
        `n_certs`, and a warning when the dataset was created with a sec-certs version
        different from the one currently running.
        """
        dset = cls(
            {cert.dgst: cert for cert in dct["certs"]},
            name=dct["name"],
            description=dct["description"],
            state=dct["state"],
        )
        claimed = dct["n_certs"]
        if len(dset) != claimed:
            logger.error(
                f"The actual number of certs in dataset ({len(dset)}) does not match the claimed number ({claimed})."
            )

        # Version check and warning
        try:
            from sec_certs._version import __version__ as current_version
        except ImportError:
            current_version = "unknown"
        dset_version = getattr(getattr(dset, "state", None), "sec_certs_version", None)

        # Nothing to report when either version is unknown or when they match exactly.
        if not dset_version or current_version == "unknown" or dset_version == current_version:
            return dset

        try:
            dset_v = parse_version(dset_version)
            curr_v = parse_version(current_version)
            which = "newer than" if dset_v > curr_v else ("older than" if dset_v < curr_v else "equal to")
            logger.warning(
                f"Dataset was created with sec-certs version {dset_version} ({which} your version {current_version}). To install the matching version: pip install sec-certs=={dset_version}"
            )
        except Exception:
            # The versions could not be compared; report the mismatch without ordering them.
            logger.warning(
                f"Dataset was created with sec-certs version {dset_version}, but you are running version {current_version}. To install the matching version: pip install sec-certs=={dset_version}"
            )
        return dset

    @classmethod
    def from_json(cls: type[DatasetSubType], input_path: str | Path, is_compressed: bool = False) -> DatasetSubType:
        """
        Loads the dataset from a json file. The directory holding the file becomes the
        root dir of the loaded dataset.

        :param input_path: path to the json file.
        :param is_compressed: passed through to `ComplexSerializableType.from_json`.
        """
        loaded = ComplexSerializableType.from_json(input_path, is_compressed)
        dset = cast("DatasetSubType", loaded)
        containing_dir = Path(input_path).parent.absolute()
        dset._root_dir = containing_dir
        dset._set_local_paths()
        return dset

    def _set_local_paths(self) -> None:
        if self.root_dir is None:
            return
        if hasattr(self, "aux_handlers") and self.aux_handlers:
            for handler in self.aux_handlers.values():
                handler.set_local_paths(self.auxiliary_datasets_dir)

    @only_backed()
    def move_dataset(self, new_root_dir: str | Path) -> None:
        """
        Moves all dataset files to `new_root_dir` and adjusts all paths internally. Deletes the artifacts from the original location.

        Moving the dataset onto its current root dir moves no files. Moving it into a
        subdirectory of the current root dir is refused: the original location is deleted
        after copying, which would delete the fresh copy as well.

        :param str | Path new_root_dir: path to directory where the new dataset shall be stored.
        :raises ValueError: if `new_root_dir` is an existing file or lies inside the current root dir.
        """
        new_root_dir = Path(new_root_dir)
        if new_root_dir.is_file():
            raise ValueError("New root dir must be a directory, not an existing file.")

        # Compare resolved paths so that relative paths and symlinks are recognised as well.
        old_resolved = Path(self.root_dir).resolve()
        new_resolved = new_root_dir.resolve()
        if new_resolved == old_resolved:
            # Already in place: copying the tree onto itself would only fail.
            self.root_dir = new_root_dir
            return
        if old_resolved in new_resolved.parents:
            raise ValueError("New root dir must not be located inside the current root dir.")

        new_root_dir.mkdir(parents=True, exist_ok=True)

        shutil.copytree(str(self.root_dir), str(new_root_dir), dirs_exist_ok=True)
        shutil.rmtree(self.root_dir)
        self.root_dir = new_root_dir

    @only_backed()
    def copy_dataset(self, new_root_dir: str | Path) -> None:
        """
        Copies all dataset files to `new_root_dir` and re-points the dataset to the copy.
        The files at the original location are kept.

        :param str | Path new_root_dir: path to directory where the new dataset shall be stored.
        :raises ValueError: if `new_root_dir` is an existing file.
        """
        destination = Path(new_root_dir)
        if destination.is_file():
            raise ValueError("New root dir must be a directory, not an existing file.")

        destination.mkdir(parents=True, exist_ok=True)
        source = str(self.root_dir)
        shutil.copytree(source, str(destination), dirs_exist_ok=True)
        self.root_dir = destination

    def get_certs_by_name(self, name: str) -> set[CertSubType]:
        """
        Returns list of certificates that match given name.
        """
        return {crt for crt in self if crt.name and crt.name == name}

    @abstractmethod
    def get_certs_from_web(self) -> None:
        raise NotImplementedError("Not meant to be implemented by the base class.")

    @staged(logger, "Processing auxiliary datasets")
    @serialize
    @only_backed()
    def process_auxiliary_datasets(self, download_fresh: bool = False, **kwargs) -> None:
        """
        Lets every auxiliary dataset handler (CPE, CVE, ...) process its dataset and
        records in the state that the auxiliary datasets were processed.

        :param download_fresh: passed to `process_dataset` of each handler.
        """
        logger.info("Processing auxiliary datasets.")
        handlers = self.aux_handlers.values()
        for aux_handler in handlers:
            aux_handler.process_dataset(download_fresh)
        self.state.auxiliary_datasets_processed = True

    @only_backed()
    def load_auxiliary_datasets(self) -> None:
        """
        Loads the auxiliary datasets into memory through their handlers. Handlers that
        already have a `dset` attribute are skipped; a failure to load is only logged.
        """
        logger.info("Loading auxiliary datasets into memory.")
        for handler in self.aux_handlers.values():
            if hasattr(handler, "dset"):
                continue
            try:
                handler.load_dataset()
            except Exception:
                # Best effort: a dataset that cannot be loaded must not stop the others.
                logger.warning(
                    f"Failed to load auxiliary dataset bound to {handler}, some functionality may not work."
                )

    @serialize
    @only_backed()
    def download_all_artifacts(self, fresh: bool = True) -> None:
        """
        Downloads all artifacts related to certification in the given scheme and records
        that in the state. Only logs an error and returns when the meta-sources were not parsed yet.

        :param fresh: passed to `_download_all_artifacts_body`; when true, the body is
            run a second time with `fresh=False`.
        """
        if self.state.meta_sources_parsed:
            logger.info("Attempting to download certification artifacts.")
            self._download_all_artifacts_body(fresh)
            if fresh:
                # NOTE(review): presumably a second attempt at the artifacts that failed
                # in the fresh pass; confirm against the subclass implementations.
                self._download_all_artifacts_body(False)

            self.state.artifacts_downloaded = True
        else:
            logger.error("Attempting to download pdfs while not having csv/html meta-sources parsed. Returning.")

    @abstractmethod
    def _download_all_artifacts_body(self, fresh: bool = True) -> None:
        raise NotImplementedError("Not meant to be implemented by the base class.")

    @serialize
    @only_backed()
    def convert_all_pdfs(self, fresh: bool = True) -> None:
        """
        Converts all pdf artifacts to txt, given the certification scheme, and records
        that in the state. Only logs an error and returns when the artifacts were not downloaded yet.

        :param fresh: passed to `_convert_all_pdfs_body`.
        """
        if self.state.artifacts_downloaded:
            logger.info("Converting all PDFs to txt")
            self._convert_all_pdfs_body(fresh)

            self.state.pdfs_converted = True
        else:
            logger.error("Attempting to convert pdfs while not having the artifacts downloaded. Returning.")

    @abstractmethod
    def _convert_all_pdfs_body(self, fresh: bool = True) -> None:
        raise NotImplementedError("Not meant to be implemented by the base class.")

    @serialize
    @only_backed()
    def analyze_certificates(self) -> None:
        """
        Does two things:
            - Extracts data from certificates (keywords, etc.)
            - Computes various heuristics on the certificates.

        The analysis is skipped (with a log message) unless both the pdf->txt conversion
        and the processing of auxiliary datasets are recorded as done in `self.state`.
        On success, `self.state.certs_analyzed` is set.
        """
        if not self.state.pdfs_converted:
            logger.info(
                "Attempting run analysis of txt files while not having the pdf->txt conversion done. Returning."
            )
            return
        if not self.state.auxiliary_datasets_processed:
            logger.info(
                "Attempting to run analysis of certifies while not having the auxiliary datasets processed. Returning."
            )
            # Honour the "Returning." announced above; without this the analysis ran anyway
            # and the dataset got marked as analyzed.
            return

        logger.info("Analyzing certificates.")
        self._analyze_certificates_body()
        self.state.certs_analyzed = True

    def _analyze_certificates_body(self) -> None:
        """
        Body of `analyze_certificates`: runs the data extraction followed by the heuristics.
        """
        logger.info("Extracting data and heuristics")
        for stage in ("extract_data", "compute_heuristics"):
            getattr(self, stage)()

    @abstractmethod
    @only_backed()
    def extract_data(self) -> None:
        """
        Abstract method; the base class provides no implementation and only raises.
        """
        reason = "Not meant to be implemented by the base class."
        raise NotImplementedError(reason)

    @serialize
    @only_backed()
    def compute_heuristics(self) -> None:
        """
        Loads the auxiliary datasets into memory and then runs `_compute_heuristics_body`.
        """
        logger.info("Computing various heuristics from the certificates.")
        for step in ("load_auxiliary_datasets", "_compute_heuristics_body"):
            getattr(self, step)()

    @abstractmethod
    def _compute_heuristics_body(self) -> None:
        raise NotImplementedError("Not meant to be implemented by the base class.")

    def get_keywords_df(self, var: str) -> pd.DataFrame:
        """
        Get dataframe of keyword hits for attribute (var) that is member of PdfData class.
        """
        data = [dict({"dgst": x.dgst}, **x.pdf_data.get_keywords_df_data(var)) for x in self]
        return pd.DataFrame(data).set_index("dgst")

    def update_with_certs(self, certs: list[CertSubType]) -> None:
        """
        Adds `certs` into the dataset, keyed by their digests; certificates already
        present under the same digest are replaced. Logs a warning when some of the
        certificates were not members of the dataset before.

        :param List[Certificate] certs: new certs to include into the dataset.
        """
        has_outsiders = any(cert not in self for cert in certs)
        if has_outsiders:
            logger.warning("Updating dataset with certificates outside of the dataset!")

        # Build the whole mapping first so that the dataset is updated in one step.
        incoming = {}
        for cert in certs:
            incoming[cert.dgst] = cert
        self.certs.update(incoming)