Source code for das.io
"""
Load data for training/testing.
See doc/data.md for a description of the data schema.
"""
import os.path
from . import npy_dir
def _select(data, x_suffix, y_suffix):
for lvl in ["test", "val", "train"]:
if lvl in data:
if "y_" + y_suffix in data[lvl]:
data[lvl]["y"] = data[lvl]["y_" + y_suffix]
if "eventtimes_" + y_suffix in data[lvl]:
data[lvl]["eventtimes"] = data[lvl]["eventtimes_" + y_suffix]
if "x_" + x_suffix in data[lvl]:
data[lvl]["x"] = data[lvl]["x_" + x_suffix]
if "eventtimes_" + x_suffix in data[lvl]:
data[lvl]["eventtimes"] = data[lvl]["eventtimes_" + x_suffix]
if f"samplerate_x_{x_suffix}_Hz" in data.attrs:
data.attrs["samplerate_x_Hz"] = data.attrs[f"samplerate_x_{x_suffix}_Hz"]
if "class_names_" + y_suffix in data.attrs and "class_types_" + y_suffix in data.attrs:
data.attrs["class_names"] = data.attrs["class_names_" + y_suffix]
data.attrs["class_types"] = data.attrs["class_types_" + y_suffix]
return data
def _to_dict(data):
"Convert dict-like zarr or h5 store `data` to python dictionary."
d = npy_dir.DictClass()
d.attrs = dict(data.attrs) # cast to dict since data.attrs are read-only for zarr stores
for key_top in data.keys():
d[key_top] = dict()
for key, val in data[key_top].items():
d[key_top][key] = val
return d
[docs]def load(location, x_suffix="", y_suffix=""):
"""Load data for training/testing from zarr store, npy directory, or hdf5 file.
Args:
location ([type]): [description]
x_suffix, y_suffix (str, optional): alternative key for the training source and target (allows for different x/y's for the same y/x in one data file)
Returns:
dict-like complying with data schema defined above
"""
location = os.path.normpath(location) # remove trailing path separators
if location.endswith(".zarr"):
import zarr
data = zarr.open(location, mode="r")
elif location.endswith(".h5"):
import h5py
data = h5py.File.open(location, mode="r")
elif location.endswith(".npy"):
data = npy_dir.load(location)
else:
raise ValueError(
f'Could not load data. Location {location} has unknown extension - needs to end either in ".zarr", ".npy", or ".h5".'
)
data = _to_dict(data)
if len(x_suffix) or len(y_suffix):
data = _select(data, x_suffix, y_suffix)
return data