diff options
| author | Tomáš Jusko | 2023-12-09 14:38:22 +0100 |
|---|---|---|
| committer | Tomáš Jusko | 2023-12-09 14:38:22 +0100 |
| commit | bd7f23b3d762ef48b60c871f7d3ffe5fe2da8c00 (patch) | |
| tree | e1ce4b5b2e0061d01657c4fdf9e5c9f710bdd228 /test | |
| parent | 51e001208bb8909879332286bb042445e79e1e5b (diff) | |
| download | pyecsca-bd7f23b3d762ef48b60c871f7d3ffe5fe2da8c00.tar.gz pyecsca-bd7f23b3d762ef48b60c871f7d3ffe5fe2da8c00.tar.zst pyecsca-bd7f23b3d762ef48b60c871f7d3ffe5fe2da8c00.zip | |
feat: Added pearson correlation coefficient to perf script
Diffstat (limited to 'test')
| -rw-r--r-- | test/sca/perf_stacked_combine.py | 71 |
1 files changed, 43 insertions, 28 deletions
diff --git a/test/sca/perf_stacked_combine.py b/test/sca/perf_stacked_combine.py index e64f847..024370e 100644 --- a/test/sca/perf_stacked_combine.py +++ b/test/sca/perf_stacked_combine.py @@ -10,6 +10,7 @@ import json import sys from typing import (Any, Callable, Dict, List, Optional, TextIO, Tuple, Union, cast) +from warnings import warn import numpy as np import numpy.random as npr @@ -30,6 +31,7 @@ TRACESET_OPS = { "variance": variance, "average_and_variance": average_and_variance, "add": add, + "pearson_corr": None, } OPERATIONS = list(TRACESET_OPS.keys()) @@ -40,8 +42,7 @@ DEVICES = ["cpu", "gpu"] def _generate_floating(rng: npr.Generator, - trace_count: int, - trace_length: int, + shape: tuple[int, ...], dtype: npt.DTypeLike = np.float32, distribution: str = "uniform", low: float = 0.0, @@ -55,7 +56,7 @@ def _generate_floating(rng: npr.Generator, or np.issubdtype(dtype, np.float64)) else np.float32) if distribution == "uniform": - samples = rng.random((trace_count, trace_length), + samples = rng.random(shape, dtype=dtype_) # type: ignore if (not np.issubdtype(dtype, np.float32) @@ -64,7 +65,7 @@ def _generate_floating(rng: npr.Generator, return (samples * (high - low) + low) elif distribution == "normal": return (rng - .normal(mean, std, (trace_count, trace_length)) + .normal(mean, std, shape) .clip(low, high) .astype(dtype)) @@ -72,8 +73,7 @@ def _generate_floating(rng: npr.Generator, def _generate_integers(rng: npr.Generator, - trace_count: int, - trace_length: int, + shape: tuple[int, ...], dtype: npt.DTypeLike = np.int32, distribution: str = "uniform", low: int = 0, @@ -86,11 +86,11 @@ def _generate_integers(rng: npr.Generator, if distribution == "uniform": return rng.integers(low, high, - size=(trace_count, trace_length), + size=shape, dtype=dtype) # type: ignore elif distribution == "normal": return (rng - .normal(mean, std, (trace_count, trace_length)) + .normal(mean, std, shape) .astype(dtype) .clip(low, high - 1)) @@ -105,8 +105,7 @@ def generate_dataset(rng: npr.Generator, low: float | int = 0, high: float | int = 1, mean: float | int = 0, - std: float | int = 1, - seed: int | None = None) -> np.ndarray: + std: float | int = 1) -> np.ndarray: """Generate a TraceSet with random samples For float dtype only float32 and float64 are supported natively, @@ -124,12 +123,12 @@ def generate_dataset(rng: npr.Generator, and not np.issubdtype(dtype, np.floating)): raise ValueError("dtype must be an integer or floating point type") + shape = (trace_count, trace_length) if trace_length > 1 else (trace_count,) gen_fun, cast_fun = ((_generate_integers, int) if np.issubdtype(dtype, np.integer) else (_generate_floating, float)) samples = gen_fun(rng, - trace_count, - trace_length, + shape, dtype, distribution, cast_fun(low), # type: ignore @@ -909,16 +908,16 @@ def _export_report_csv(time_storage: List[tuple[Namespace, aggr_writer) -def export_report(time_storage: List[tuple[Namespace, - List[List[TimeRecord]]]], +def export_report(args_durations: List[tuple[Namespace, + List[List[TimeRecord]]]], export_format: str, **kwargs) -> None: if export_format == "json": - _export_report_json(time_storage, **kwargs) + _export_report_json(args_durations, **kwargs) elif export_format == "csv": - _export_report_csv(time_storage, **kwargs) + _export_report_csv(args_durations, **kwargs) else: - raise ValueError("Unknown export format") + raise ValueError(f"Unknown export format {export_format}") def repetition(args: Namespace, @@ -937,11 +936,12 @@ def repetition(args: Namespace, args.low, args.high, args.mean, - args.std, - args.seed) + args.std) # Transform data for operations input - if args.stack: + if not args.stack: + data = to_traceset(dataset) + else: if args.verbose: print("Stacking data...") data = stack(dataset, @@ -950,8 +950,6 @@ def repetition(args: Namespace, args.time, time_storage, args.verbose) - else: - data = to_traceset(dataset) if not args.operations: print_times(time_storage) @@ -960,10 +958,10 @@ def repetition(args: Namespace, if args.verbose: print("Performing operations...") - # Operations on stacked traces if args.stack: - # Initialize trace manager + # Operations on stacked traces assert isinstance(data, StackedTraces) + # Initialize trace manager trace_manager = (CPUTraceManager(data) if args.device == "cpu" else GPUTraceManager( @@ -977,16 +975,32 @@ def repetition(args: Namespace, for op in args.operations: if args.verbose: print(f"Performing {op}...") - op_func = getattr(trace_manager, op) - timed(time_storage, args.verbose, args.time)(op_func)() + op_func = getattr(trace_manager, op, None) + if op_func is None: + warn(f"Unknown operation {op}") + continue + inputs = () + if op == "pearson_corr": + inputs = (generate_dataset(rng, + args.trace_count, + 1, + args.dtype, + args.distribution, + args.low, + args.high, + args.mean, + args.std),) + timed(time_storage, args.verbose, args.time)(op_func)(*inputs) else: assert isinstance(data, TraceSet) - # Perform operations for op in args.operations: if args.verbose: print(f"Performing {op}...") - op_func = TRACESET_OPS[op] + op_func = TRACESET_OPS.get(op) + if op_func is None: + warn(f"Unknown operation {op}") + continue timed(time_storage, args.verbose, args.time)(op_func)(*data) if args.verbose: @@ -1017,6 +1031,7 @@ def main(args: Namespace) -> List[List[TimeRecord]]: if __name__ == "__main__": args_list = _get_args(_get_parser()) + assert args_list results = [] try: for args in args_list: |
