aboutsummaryrefslogtreecommitdiffhomepage
path: root/pyecsca
diff options
context:
space:
mode:
authorJ08nY2020-03-09 17:58:41 +0100
committerJ08nY2020-03-09 17:58:41 +0100
commitf33ff9e95c3404be398e4c32f80cbf7adf03b981 (patch)
treecf173b6c00dbf3b40ea2f37932882541eaa0a480 /pyecsca
parent1615fcb61bbcab6e1f2ac5b9282aaf0a7a5978d8 (diff)
downloadpyecsca-f33ff9e95c3404be398e4c32f80cbf7adf03b981.tar.gz
pyecsca-f33ff9e95c3404be398e4c32f80cbf7adf03b981.tar.zst
pyecsca-f33ff9e95c3404be398e4c32f80cbf7adf03b981.zip
Fix alignment of traces of unequal length.
Diffstat (limited to 'pyecsca')
-rw-r--r--pyecsca/sca/trace/align.py6
-rw-r--r--pyecsca/sca/trace/trace.py2
-rw-r--r--pyecsca/sca/trace_set/hdf5.py19
3 files changed, 14 insertions, 13 deletions
diff --git a/pyecsca/sca/trace/align.py b/pyecsca/sca/trace/align.py
index f1bf259..dc246eb 100644
--- a/pyecsca/sca/trace/align.py
+++ b/pyecsca/sca/trace/align.py
@@ -191,8 +191,8 @@ def align_dtw_scale(reference: Trace, *traces: Trace, radius: int = 1,
dist, path = fastdtw(reference_samples, trace.samples, radius=radius)
else:
dist, path = dtw(reference_samples, trace.samples)
- result_samples = np.zeros(len(trace.samples), dtype=trace.samples.dtype)
- scale = np.ones(len(trace.samples), dtype=trace.samples.dtype)
+ result_samples = np.zeros(max((len(trace.samples), len(reference_samples))), dtype=trace.samples.dtype)
+ scale = np.ones(max((len(trace.samples), len(reference_samples))), dtype=trace.samples.dtype)
for x, y in path:
result_samples[x] = trace.samples[y]
scale[x] += 1
@@ -225,7 +225,7 @@ def align_dtw(reference: Trace, *traces: Trace, radius: int = 1, fast: bool = Tr
dist, path = fastdtw(reference_samples, trace.samples, radius=radius)
else:
dist, path = dtw(reference_samples, trace.samples)
- result_samples = np.zeros(len(trace.samples), dtype=trace.samples.dtype)
+ result_samples = np.zeros(max((len(trace.samples), len(reference_samples))), dtype=trace.samples.dtype)
pairs = np.array(np.array(path, dtype=np.dtype("int,int")),
dtype=np.dtype([("x", "int"), ("y", "int")]))
result_samples[pairs["x"]] = trace.samples[pairs["y"]]
diff --git a/pyecsca/sca/trace/trace.py b/pyecsca/sca/trace/trace.py
index 6ba70bc..764dada 100644
--- a/pyecsca/sca/trace/trace.py
+++ b/pyecsca/sca/trace/trace.py
@@ -64,7 +64,7 @@ class Trace(object):
return np.array_equal(self.samples, other.samples) and self.meta == other.meta
def with_samples(self, samples: ndarray) -> "Trace":
- return Trace(samples, deepcopy(self.meta), deepcopy(self.trace_set))
+ return Trace(samples, deepcopy(self.meta))
def __copy__(self):
return Trace(copy(self.samples), copy(self.meta), copy(self.trace_set))
diff --git a/pyecsca/sca/trace_set/hdf5.py b/pyecsca/sca/trace_set/hdf5.py
index 6f676ff..b610735 100644
--- a/pyecsca/sca/trace_set/hdf5.py
+++ b/pyecsca/sca/trace_set/hdf5.py
@@ -8,6 +8,7 @@ from typing import Union, Optional, Dict, Any, List
import h5py
import numpy as np
from public import public
+from copy import deepcopy
from .base import TraceSet
from .. import Trace
@@ -16,27 +17,27 @@ from .. import Trace
@public
class HDF5Meta(MutableMapping):
_dataset: h5py.AttributeManager
- _cache: Dict[str, Any]
def __init__(self, attrs: h5py.AttributeManager):
self._attrs = attrs
- self._cache = {}
super().__init__()
def __getitem__(self, item):
if item not in self._attrs:
raise KeyError
- if item not in self._cache:
- self._cache[item] = pickle.loads(self._attrs[item])
- return self._cache[item]
+ return pickle.loads(self._attrs[item])
def __setitem__(self, key, value):
self._attrs[key] = np.void(pickle.dumps(value))
def __delitem__(self, key):
del self._attrs[key]
- if key in self._cache:
- del self._cache[key]
+
+ def __copy__(self):
+ return deepcopy(self)
+
+ def __deepcopy__(self, memodict={}):
+ return dict(self)
def __iter__(self):
yield from self._attrs
@@ -69,7 +70,7 @@ class HDF5TraceSet(TraceSet):
else:
raise ValueError
kwargs = dict(hdf5.attrs)
- kwargs["_ordering"] = list(kwargs["_ordering"])
+ kwargs["_ordering"] = list(kwargs["_ordering"]) if "_ordering" in kwargs else list(hdf5.keys())
traces = []
for k in kwargs["_ordering"]:
meta = dict(HDF5Meta(hdf5[k].attrs))
@@ -87,7 +88,7 @@ class HDF5TraceSet(TraceSet):
else:
raise ValueError
kwargs = dict(hdf5.attrs)
- kwargs["_ordering"] = list(kwargs["_ordering"])
+ kwargs["_ordering"] = list(kwargs["_ordering"]) if "_ordering" in kwargs else list(hdf5.keys())
traces = []
for k in kwargs["_ordering"]:
meta = HDF5Meta(hdf5[k].attrs)