aboutsummaryrefslogtreecommitdiff
path: root/test/utils.py
blob: 1276ad6393ae643c6104c947c4d2e709f29aa339 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import pstats
import sys
import time

from pathlib import Path
from subprocess import run, PIPE, DEVNULL
from typing import Union, Literal

from pyinstrument import Profiler as PyProfiler
from cProfile import Profile as cProfiler


class RawTimer:
    """Minimal context manager that times a block with ``time.perf_counter_ns``.

    After the ``with`` block finishes, ``start`` and ``end`` hold the raw
    counter readings in nanoseconds and ``duration`` holds the elapsed time
    in seconds.
    """

    start: int  # perf_counter_ns reading taken on entry
    end: int  # perf_counter_ns reading taken on exit
    duration: float  # elapsed time in seconds, set on exit

    def __enter__(self):
        """Record the start reading and return the timer itself.

        Returning ``self`` lets callers write ``with RawTimer() as t:`` and
        read ``t.duration`` afterwards.
        """
        self.start = time.perf_counter_ns()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Record the end reading and compute ``duration`` in seconds.

        Exceptions raised inside the block are not suppressed.
        """
        self.end = time.perf_counter_ns()
        # Counter readings are nanoseconds; convert the difference to seconds.
        self.duration = (self.end - self.start) / 1e9


class Profiler:
    """Context manager that profiles a block, prints a report and logs the time.

    One of three backends is selected by ``prof_type``: ``"py"`` uses
    ``PyProfiler`` (pyinstrument), ``"c"`` uses ``cProfiler`` (cProfile) and
    ``"raw"`` uses ``RawTimer``.  On leaving the ``with`` block the report is
    printed via :meth:`output` and a CSV row is appended via :meth:`save`.

    Args:
        prof_type: Backend selector: ``"py"``, ``"c"`` or ``"raw"``.  Any
            other value raises ``KeyError``.
        output_directory: Directory holding the CSV file.  ``None`` disables
            saving.
        benchmark_name: Base name of the CSV file (``.csv`` is appended).
            ``None`` disables saving.
        operations: Number of operations performed inside the block; used
            only by the ``"raw"`` report to print a throughput figure.  ``0``
            (the default) means "unknown" and prints ``-`` instead.
    """

    def __init__(
        self,
        prof_type: Literal["py", "c", "raw"],
        output_directory: Union[str, None],
        benchmark_name: Union[str, None],
        operations: int = 0,
    ):
        backends = {
            "py": PyProfiler,
            "c": cProfiler,
            "raw": RawTimer,
        }
        self._prof: Union[PyProfiler, cProfiler, RawTimer] = backends[prof_type]()
        self._prof_type: Literal["py", "c", "raw"] = prof_type
        # Root frame of the pyinstrument session; only set for the "py" backend.
        self._root_frame = None
        # "in" while the profiled block is running, "out" otherwise.
        self._state = "out"
        self._output_directory = output_directory
        self._benchmark_name = benchmark_name
        self._operations = operations

    def __enter__(self):
        """Start the underlying profiler and return this wrapper."""
        self._prof.__enter__()
        self._state = "in"
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Stop the underlying profiler, then print and save the results."""
        self._prof.__exit__(exc_type, exc_val, exc_tb)
        if self._prof_type == "py":
            self._root_frame = self._prof.last_session.root_frame()  # type: ignore
        self._state = "out"
        self.output()
        self.save()

    def save(self):
        """Append ``<version>,<python version>,<seconds>`` to the benchmark CSV.

        ``<version>`` is the short git commit hash of ``HEAD`` with ``-dirty``
        appended when ``git diff --quiet`` exits non-zero.  Nothing is written
        when either ``output_directory`` or ``benchmark_name`` is ``None``.

        Raises:
            ValueError: If called while the profiled block is still running.
        """
        if self._state != "out":
            raise ValueError("cannot save results while the profiler is running")
        if self._output_directory is None or self._benchmark_name is None:
            return
        commit_proc = run(
            ["git", "rev-parse", "--short", "HEAD"],
            stdout=PIPE,
            stderr=DEVNULL,
            check=False,
        )
        git_commit = commit_proc.stdout.strip().decode()
        diff_proc = run(
            ["git", "diff", "--quiet"], stdout=DEVNULL, stderr=DEVNULL, check=False
        )
        # Any non-zero exit status from `git diff --quiet` is treated as dirty.
        git_dirty = diff_proc.returncode != 0
        version = git_commit + ("-dirty" if git_dirty else "")
        python_version = ".".join(map(str, sys.version_info[:3]))
        output_path = Path(self._output_directory) / (self._benchmark_name + ".csv")
        with output_path.open("a") as f:
            f.write(f"{version},{python_version},{self.get_time()}\n")

    def _raw_summary(self) -> str:
        """Return the one-line report for the ``"raw"`` backend.

        The throughput is shown as ``-`` when no operation count was given
        (or the measured duration is zero), so the numeric format spec is
        only ever applied to a number.
        """
        duration = self._prof.duration  # type: ignore
        if self._operations and duration > 0:
            rate = f"{self._operations / duration:.1f}"
        else:
            rate = "-"
        return f"{duration:.4f}s {rate}op/s"

    def output(self):
        """Print the backend-specific report to standard output.

        Raises:
            ValueError: If called while the profiled block is still running.
        """
        if self._state != "out":
            raise ValueError("cannot print results while the profiler is running")
        if self._prof_type == "py":
            print(self._prof.output_text(unicode=True, color=True))  # type: ignore
        elif self._prof_type == "c":
            self._prof.print_stats("cumtime")  # type: ignore
        elif self._prof_type == "raw":
            print(self._raw_summary())

    def get_time(self) -> float:
        """Return the measured time as reported by the selected backend.

        Raises:
            ValueError: If called while the profiled block is still running.
        """
        if self._state != "out":
            raise ValueError("cannot read the time while the profiler is running")
        if self._prof_type == "py":
            return self._root_frame.time  # type: ignore
        elif self._prof_type == "c":
            return pstats.Stats(self._prof).total_tt  # type: ignore
        elif self._prof_type == "raw":
            return self._prof.duration  # type: ignore