aboutsummaryrefslogtreecommitdiffhomepage
path: root/test
diff options
context:
space:
mode:
authorTomáš Jusko2023-12-09 14:38:22 +0100
committerTomáš Jusko2023-12-09 14:38:22 +0100
commitbd7f23b3d762ef48b60c871f7d3ffe5fe2da8c00 (patch)
treee1ce4b5b2e0061d01657c4fdf9e5c9f710bdd228 /test
parent51e001208bb8909879332286bb042445e79e1e5b (diff)
downloadpyecsca-bd7f23b3d762ef48b60c871f7d3ffe5fe2da8c00.tar.gz
pyecsca-bd7f23b3d762ef48b60c871f7d3ffe5fe2da8c00.tar.zst
pyecsca-bd7f23b3d762ef48b60c871f7d3ffe5fe2da8c00.zip
feat: Added pearson correlation coefficient to perf script
Diffstat (limited to 'test')
-rw-r--r--test/sca/perf_stacked_combine.py71
1 files changed, 43 insertions, 28 deletions
diff --git a/test/sca/perf_stacked_combine.py b/test/sca/perf_stacked_combine.py
index e64f847..024370e 100644
--- a/test/sca/perf_stacked_combine.py
+++ b/test/sca/perf_stacked_combine.py
@@ -10,6 +10,7 @@ import json
import sys
from typing import (Any, Callable, Dict, List, Optional, TextIO,
Tuple, Union, cast)
+from warnings import warn
import numpy as np
import numpy.random as npr
@@ -30,6 +31,7 @@ TRACESET_OPS = {
"variance": variance,
"average_and_variance": average_and_variance,
"add": add,
+ "pearson_corr": None,
}
OPERATIONS = list(TRACESET_OPS.keys())
@@ -40,8 +42,7 @@ DEVICES = ["cpu", "gpu"]
def _generate_floating(rng: npr.Generator,
- trace_count: int,
- trace_length: int,
+ shape: tuple[int, ...],
dtype: npt.DTypeLike = np.float32,
distribution: str = "uniform",
low: float = 0.0,
@@ -55,7 +56,7 @@ def _generate_floating(rng: npr.Generator,
or np.issubdtype(dtype, np.float64))
else np.float32)
if distribution == "uniform":
- samples = rng.random((trace_count, trace_length),
+ samples = rng.random(shape,
dtype=dtype_) # type: ignore
if (not np.issubdtype(dtype, np.float32)
@@ -64,7 +65,7 @@ def _generate_floating(rng: npr.Generator,
return (samples * (high - low) + low)
elif distribution == "normal":
return (rng
- .normal(mean, std, (trace_count, trace_length))
+ .normal(mean, std, shape)
.clip(low, high)
.astype(dtype))
@@ -72,8 +73,7 @@ def _generate_floating(rng: npr.Generator,
def _generate_integers(rng: npr.Generator,
- trace_count: int,
- trace_length: int,
+ shape: tuple[int, ...],
dtype: npt.DTypeLike = np.int32,
distribution: str = "uniform",
low: int = 0,
@@ -86,11 +86,11 @@ def _generate_integers(rng: npr.Generator,
if distribution == "uniform":
return rng.integers(low,
high,
- size=(trace_count, trace_length),
+ size=shape,
dtype=dtype) # type: ignore
elif distribution == "normal":
return (rng
- .normal(mean, std, (trace_count, trace_length))
+ .normal(mean, std, shape)
.astype(dtype)
.clip(low, high - 1))
@@ -105,8 +105,7 @@ def generate_dataset(rng: npr.Generator,
low: float | int = 0,
high: float | int = 1,
mean: float | int = 0,
- std: float | int = 1,
- seed: int | None = None) -> np.ndarray:
+ std: float | int = 1) -> np.ndarray:
"""Generate a TraceSet with random samples
For float dtype only float32 and float64 are supported natively,
@@ -124,12 +123,12 @@ def generate_dataset(rng: npr.Generator,
and not np.issubdtype(dtype, np.floating)):
raise ValueError("dtype must be an integer or floating point type")
+ shape = (trace_count, trace_length) if trace_length > 1 else (trace_count,)
gen_fun, cast_fun = ((_generate_integers, int)
if np.issubdtype(dtype, np.integer) else
(_generate_floating, float))
samples = gen_fun(rng,
- trace_count,
- trace_length,
+ shape,
dtype,
distribution,
cast_fun(low), # type: ignore
@@ -909,16 +908,16 @@ def _export_report_csv(time_storage: List[tuple[Namespace,
aggr_writer)
-def export_report(time_storage: List[tuple[Namespace,
- List[List[TimeRecord]]]],
+def export_report(args_durations: List[tuple[Namespace,
+ List[List[TimeRecord]]]],
export_format: str,
**kwargs) -> None:
if export_format == "json":
- _export_report_json(time_storage, **kwargs)
+ _export_report_json(args_durations, **kwargs)
elif export_format == "csv":
- _export_report_csv(time_storage, **kwargs)
+ _export_report_csv(args_durations, **kwargs)
else:
- raise ValueError("Unknown export format")
+ raise ValueError(f"Unknown export format {export_format}")
def repetition(args: Namespace,
@@ -937,11 +936,12 @@ def repetition(args: Namespace,
args.low,
args.high,
args.mean,
- args.std,
- args.seed)
+ args.std)
# Transform data for operations input
- if args.stack:
+ if not args.stack:
+ data = to_traceset(dataset)
+ else:
if args.verbose:
print("Stacking data...")
data = stack(dataset,
@@ -950,8 +950,6 @@ def repetition(args: Namespace,
args.time,
time_storage,
args.verbose)
- else:
- data = to_traceset(dataset)
if not args.operations:
print_times(time_storage)
@@ -960,10 +958,10 @@ def repetition(args: Namespace,
if args.verbose:
print("Performing operations...")
- # Operations on stacked traces
if args.stack:
- # Initialize trace manager
+ # Operations on stacked traces
assert isinstance(data, StackedTraces)
+ # Initialize trace manager
trace_manager = (CPUTraceManager(data)
if args.device == "cpu"
else GPUTraceManager(
@@ -977,16 +975,32 @@ def repetition(args: Namespace,
for op in args.operations:
if args.verbose:
print(f"Performing {op}...")
- op_func = getattr(trace_manager, op)
- timed(time_storage, args.verbose, args.time)(op_func)()
+ op_func = getattr(trace_manager, op, None)
+ if op_func is None:
+ warn(f"Unknown operation {op}")
+ continue
+ inputs = ()
+ if op == "pearson_corr":
+ inputs = (generate_dataset(rng,
+ args.trace_count,
+ 1,
+ args.dtype,
+ args.distribution,
+ args.low,
+ args.high,
+ args.mean,
+ args.std),)
+ timed(time_storage, args.verbose, args.time)(op_func)(*inputs)
else:
assert isinstance(data, TraceSet)
-
# Perform operations
for op in args.operations:
if args.verbose:
print(f"Performing {op}...")
- op_func = TRACESET_OPS[op]
+ op_func = TRACESET_OPS.get(op)
+ if op_func is None:
+ warn(f"Unknown operation {op}")
+ continue
timed(time_storage, args.verbose, args.time)(op_func)(*data)
if args.verbose:
@@ -1017,6 +1031,7 @@ def main(args: Namespace) -> List[List[TimeRecord]]:
if __name__ == "__main__":
args_list = _get_args(_get_parser())
+ assert args_list
results = []
try:
for args in args_list: