Source code for das.io

"""
Load data for training/testing.
See doc/data.md for a description of the data schema.
"""
import os.path
from . import npy_dir
import zarr
import h5py
import numpy as np


class MemoryMappedDirectoryStore(zarr.storage.DirectoryStore):
    """Directory store that reads its files through read-only memory maps.

    Faster access to zarr files via memmapping, see
    https://gist.github.com/ivirshup/5c7df5ed10517abf6567a6a9af6c7eaa
    """

    def _fromfile(self, fn):
        """Return the contents of file `fn` as a memoryview over a read-only memmap.

        NOTE(review): presumably overrides the hook `DirectoryStore` uses to read
        a file from disk - confirm against the zarr version in use.
        """
        mapped = np.memmap(fn, mode="r")
        return memoryview(mapped)
def _select(data, x_suffix, y_suffix):
    """Alias the suffixed x/y variants in `data` to the default keys.

    For each of the "test", "val", and "train" groups present in `data`,
    copies `y_<y_suffix>`, `eventtimes_<y_suffix>`, `x_<x_suffix>`, and
    `eventtimes_<x_suffix>` (when they exist) to `y`, `eventtimes`, `x`, and
    `eventtimes`. The matching sample rate and class names/types in
    `data.attrs` are aliased the same way.

    Args:
        data: dict-like with an `attrs` mapping; modified in place.
        x_suffix (str): suffix selecting the source variant.
        y_suffix (str): suffix selecting the target variant.

    Returns:
        The same `data` object that was passed in.
    """
    # Order matters: the x event times are applied last and therefore take
    # precedence over the y event times if both exist.
    aliases = (
        ("y", "y_" + y_suffix),
        ("eventtimes", "eventtimes_" + y_suffix),
        ("x", "x_" + x_suffix),
        ("eventtimes", "eventtimes_" + x_suffix),
    )
    for split in ("test", "val", "train"):
        if split not in data:
            continue
        group = data[split]
        for target_key, source_key in aliases:
            if source_key in group:
                group[target_key] = group[source_key]

    attrs = data.attrs
    rate_key = f"samplerate_x_{x_suffix}_Hz"
    if rate_key in attrs:
        attrs["samplerate_x_Hz"] = attrs[rate_key]

    # Class names and types are only aliased together, never one without the other.
    names_key = "class_names_" + y_suffix
    types_key = "class_types_" + y_suffix
    if names_key in attrs and types_key in attrs:
        attrs["class_names"] = attrs[names_key]
        attrs["class_types"] = attrs[types_key]

    return data


def _to_dict(data):
    """Convert dict-like zarr or h5 store `data` to python dictionary.

    The top-level groups become plain dicts; the values inside each group are
    passed through unchanged.
    """
    converted = npy_dir.DictClass()
    # cast to dict since data.attrs are read-only for zarr stores
    converted.attrs = dict(data.attrs)
    for group_name in data.keys():
        converted[group_name] = dict(data[group_name].items())
    return converted
def load(location, x_suffix="", y_suffix=""):
    """Load data for training/testing from zarr store, npy directory, or hdf5 file.

    Args:
        location (str): Path to the data. Must end in ".zarr", ".npy", or ".h5";
            trailing path separators are ignored.
        x_suffix, y_suffix (str, optional): alternative key for the training source
            and target (allows for different x/y's for the same y/x in one data file)

    Returns:
        dict-like complying with the data schema described in doc/data.md

    Raises:
        ValueError: If `location` does not end in a supported extension.
    """
    location = os.path.normpath(location)  # remove trailing path separators
    if location.endswith(".zarr"):
        # LRU cache on top of the memory-mapped directory store
        store = zarr.LRUStoreCache(MemoryMappedDirectoryStore(location), max_size=8e9)
        data = zarr.group(store=store, overwrite=False)
    elif location.endswith(".h5"):
        # h5py files are opened via the File constructor; there is no `File.open`.
        # The file is deliberately left open: the returned dict holds references
        # into it rather than copies.
        data = h5py.File(location, mode="r")
    elif location.endswith(".npy"):
        data = npy_dir.load(location)
    else:
        raise ValueError(
            f'Could not load data. Location {location} has unknown extension - needs to end either in ".zarr", ".npy", or ".h5".'
        )

    data = _to_dict(data)
    # Only remap keys when at least one suffix was requested.
    if x_suffix or y_suffix:
        data = _select(data, x_suffix, y_suffix)
    return data