"""General utilities"""
import tensorflow.keras as keras
import logging
import time
import numpy as np
import yaml
import h5py
import scipy.signal
from . import kapre
from . import tcn
from . import models
from typing import Dict, Callable, Any, List, Tuple, Optional
def load_model(
    file_trunk: str,
    model_dict: Dict[str, Callable],
    model_ext: str = "_model.h5",
    params_ext: str = "_params.yaml",
    compile: bool = True,
    custom_objects: Optional[Dict[str, Callable]] = None,
):
    """Load a model together with its weights.

    First tries to load the full model directly with `keras.models.load_model`.
    If that raises `SystemError`, `ValueError` or `AttributeError` (likely for
    models with custom layers), the architecture is re-created from the saved
    parameters and the weights are loaded into it via `load_model_from_params`.

    Args:
        file_trunk (str): Common prefix of the model and parameter files
            (local path or URL); the extensions below are appended to it.
        model_dict (Dict[str, Callable]): Maps model names to the functions that
            build them. Only used by the fallback path.
        model_ext (str, optional): Suffix of the model file. Defaults to '_model.h5'.
        params_ext (str, optional): Suffix of the parameter file. Defaults to '_params.yaml'.
        compile (bool, optional): Whether to compile the loaded model. Defaults to True.
        custom_objects (dict, optional): Custom layers/objects handed to
            `keras.models.load_model`. Defaults to the `Spectrogram` and `TCN` layers.

    Returns:
        keras.Model
    """
    if custom_objects is None:
        custom_objects = {"Spectrogram": kapre.time_frequency.Spectrogram, "TCN": tcn.tcn_new.TCN}

    try:
        model_filename = _download_if_url(file_trunk + model_ext)
        # Forward `compile` so that compile=False is honoured on the direct-load path too
        # (previously only the fallback path respected it).
        model = keras.models.load_model(model_filename, custom_objects=custom_objects, compile=compile)
    except (SystemError, ValueError, AttributeError):
        logging.debug(
            "Failed to load model using keras, likely because it contains custom layers. Will try to init model architecture from code and load weights from `_model.h5` into it.",
            exc_info=False,
        )
        # Second call only attaches the traceback of the failed direct load.
        logging.debug("", exc_info=True)
        model = load_model_from_params(file_trunk, model_dict, weights_ext=model_ext, params_ext=params_ext, compile=compile)
    return model
def load_model_from_params(
    file_trunk: str,
    model_dict: Dict[str, Callable],
    weights_ext: str = "_model.h5",
    params_ext: str = "_params.yaml",
    compile: bool = True,
):
    """Build the architecture from code, then load the saved weights into it.

    Helps with model loading issues across TF versions.

    Args:
        file_trunk (str): Common prefix of the weights and parameter files.
        model_dict (Dict[str, Callable]): Maps model names to model-building functions.
            The entry for `params["model_name"]` is called with all parameters as keyword arguments.
        weights_ext (str, optional): Suffix of the weights file.
            Defaults to '_model.h5' (use weights from model file).
        params_ext (str, optional): Suffix of the parameter file. Defaults to '_params.yaml'.
        compile (bool, optional): Whether to compile the model. Defaults to True.

    Returns:
        keras.Model
    """
    params = load_params(file_trunk, params_ext=params_ext)

    # Look up the function that generates the model and call it with the saved parameters.
    build_model = model_dict[params["model_name"]]
    model = build_model(**params)

    weights_path = _download_if_url(file_trunk + weights_ext)
    model.load_weights(weights_path)

    if not compile:
        return model

    # Compile with a standard optimizer and loss so the model can be used for prediction.
    # Re-compile the model if you want a particular optimizer and loss.
    model.compile(optimizer=keras.optimizers.Adam(amsgrad=True), loss="mean_squared_error")
    return model
def save_params(params: Dict[str, Any], file_trunk: str, params_ext: str = "_params.yaml"):
    """Write model/training parameters to a yaml file.

    Args:
        params (Dict[str, Any]): Parameter dictionary to save.
        file_trunk (str): File name without the extension.
        params_ext (str, optional): Suffix appended to `file_trunk`. Defaults to '_params.yaml'.
    """
    target = file_trunk + params_ext
    with open(target, "w") as stream:
        yaml.dump(params, stream)
def load_params(file_trunk: str, params_ext: str = "_params.yaml") -> Dict[str, Any]:
    """Load model/training parameters from yaml.

    Args:
        file_trunk (str): File name (local path or URL) without the extension.
        params_ext (str, optional): Suffix appended to `file_trunk`. Defaults to '_params.yaml'.

    Returns:
        Dict[str, Any]: Parameter dictionary
    """
    filename = _download_if_url(file_trunk + params_ext)

    # Choose the loader up front instead of catching AttributeError around the parse:
    # an AttributeError raised *while parsing* would otherwise trigger a second read
    # from the already consumed stream with `yaml.load(f)`, which additionally
    # requires a `Loader` argument in PyYAML >= 6.
    # PyYAML < 5.1 has no `unsafe_load`; there, plain `yaml.load` has the same semantics.
    load_yaml = getattr(yaml, "unsafe_load", None)
    if load_yaml is None:
        load_yaml = yaml.load

    # SECURITY NOTE(review): the unsafe loader can construct arbitrary Python objects.
    # Only load parameter files (in particular URLs) from trusted sources.
    with open(filename, "r") as f:
        params = load_yaml(f)
    return params
def load_model_and_params(
    model_save_name, model_dict=models.model_dict, custom_objects=None
) -> Tuple[keras.Model, Dict[str, Any]]:
    """Load a saved model and the parameters stored alongside it.

    Args:
        model_save_name (str): Common file prefix of the model and parameter files.
        model_dict (Dict[str, Callable], optional): Maps model names to model-building
            functions. Defaults to models.model_dict.
        custom_objects (dict, optional): Passed on to `load_model`. Defaults to None.

    Returns:
        keras.Model, Dict[str, Any]: The model and its parameter dictionary.
    """
    # Parameters are read first, then the model - same order as before.
    loaded_params = load_params(model_save_name)
    loaded_model = load_model(model_save_name, model_dict=model_dict, custom_objects=custom_objects)
    return loaded_model, loaded_params
def _download_if_url(url: str) -> str:
    """Return a local file name for `url`, downloading it first if it is an http(s) URL.

    Args:
        url (str): Local path or http(s) URL.

    Returns:
        str: `url` itself if it is not an http(s) URL, otherwise the path of the
            downloaded copy inside a freshly created temporary directory.
    """
    # Require the full scheme: a bare "http" prefix also matches local paths
    # such as "http_models/run1_model.h5".
    if not url.lower().startswith(("http://", "https://")):
        return url

    import tempfile
    import urllib.request
    from pathlib import Path

    filename = url.split("/")[-1]  # last URL component is used as the file name
    tmpdir = tempfile.mkdtemp()
    # Return a plain string in both branches so callers always get the same type.
    local_path = str(Path(tmpdir) / filename)
    urllib.request.urlretrieve(url, local_path)
    return local_path
def load_from(filename: str, datasets: List[str]):
    """Read the named datasets from an h5 file.

    Args:
        filename (str): Path of the h5 file.
        datasets (List[str]): Names of the datasets (=keys) to load.

    Returns:
        dict: Maps each dataset name to its full contents (read via `[:]`).
    """
    loaded = {}
    with h5py.File(filename, "r") as h5file:
        for name in datasets:
            # `[:]` reads the data while the file is still open.
            loaded[name] = h5file[name][:]
    return loaded
class Timer:
    """Context manager that measures wall-clock time with `time.perf_counter`.

    After the `with` block, `elapsed` holds the duration in seconds.
    `str(timer)` reports the current state of the timer.
    """

    def __init__(self, verbose=False):
        """Init the timer.

        Args:
            verbose (bool, optional): Print the timer when the `with` block exits.
                Defaults to False.
        """
        self.verbose = verbose
        self.start = None
        self.end = None
        self.elapsed = None

    def __enter__(self):
        # Clear results of a previous run so that a re-used timer reports
        # "still running" instead of the stale elapsed time.
        self.end = None
        self.elapsed = None
        self.start = time.perf_counter()
        return self

    def __exit__(self, *args):
        self.end = time.perf_counter()
        self.elapsed = self.end - self.start
        if self.verbose:
            print(self)

    def __str__(self):
        if self.start is None:
            s = "Timer not started yet."
        elif self.end is None:
            s = "Timer still running."
        elif self.elapsed is not None:
            s = f"Time elapsed {self.elapsed:1.2f} seconds."
        else:
            s = "Timer in unexpected state."
        return s
class QtProgressCallback(keras.callbacks.Callback):
    """Keras callback that reports progress through a queue and stops training on request."""

    def __init__(self, nb_epochs, comms):
        """Init the callback.

        Args:
            nb_epochs (int): number of training epochs
            comms (tuple): tuple of (multiprocessing.Queue, threading.Event)
                           The queue is used to transmit progress updates to the GUI,
                           the event is set in the GUI to stop training.
        """
        super().__init__()
        self.nb_epochs = nb_epochs
        self.queue = comms[0]
        self.stop_event = comms[1]

    def _check_if_stopped(self):
        """Flag the model to stop training once the stop event is set."""
        try:
            stop_requested = self.stop_event.is_set()
            if stop_requested:
                self.model.stop_training = True
        except Exception as err:
            # Best effort: errors are printed, not raised.
            print(err)

    def on_train_begin(self, logs=None):
        """Announce the start of training on the queue."""
        message = (0, "Starting training.")
        self.queue.put(message)

    def on_train_end(self, logs=None):
        """Announce the end of training on the queue (progress value -1)."""
        message = (-1, "Finishing training.")
        self.queue.put(message)

    def on_epoch_end(self, epoch, logs=None):
        """Put the epoch index and a progress text on the queue."""
        message = (epoch, f"Epoch {epoch}/{self.nb_epochs}")
        self.queue.put(message)

    def on_train_batch_end(self, batch, logs=None):
        """Check for a stop request after each training batch."""
        self._check_if_stopped()

    def on_test_batch_end(self, batch, logs=None):
        """Check for a stop request after each test batch."""
        self._check_if_stopped()

    def on_predict_batch_end(self, batch, logs=None):
        """Check for a stop request after each prediction batch."""
        self._check_if_stopped()
def resample(x: np.ndarray, fs_audio: float, fs_model: float) -> np.ndarray:
    """Resample audio to model rate.

    Both rates are rounded down to the nearest even integer before the
    up/down-sampling factors are computed.

    Args:
        x (np.ndarray): Audio data; resampled along the first axis.
        fs_audio (float): Sampling rate of `x` in Hz.
        fs_model (float): Target sampling rate in Hz.

    Returns:
        np.ndarray: Audio resampled to fs_model.

    Raises:
        ValueError: If one of the rates rounds to an even rate of 0 or less.
    """
    fs_audio_even = int(fs_audio // 2) * 2
    fs_model_even = int(fs_model // 2) * 2
    if fs_audio_even <= 0 or fs_model_even <= 0:
        raise ValueError(f"Sampling rates must be at least 2 Hz, got fs_audio={fs_audio} and fs_model={fs_model}.")

    gcd = int(np.gcd(fs_audio_even, fs_model_even))
    # resample_poly changes the rate by a factor up / down, so going from
    # fs_audio to fs_model requires up ~ fs_model and down ~ fs_audio
    # (the arguments were swapped before).
    up = fs_model_even // gcd
    down = fs_audio_even // gcd
    x = scipy.signal.resample_poly(x, up, down, axis=0)
    return x
def bandpass_filter_song(
    x: np.ndarray, sampling_rate_hz: float, f_low: Optional[float] = None, f_high: Optional[float] = None
) -> np.ndarray:
    """Band-pass filter audio data along the first axis.

    Applies a 5th-order Butterworth band-pass (second-order sections)
    forward and backward via `sosfiltfilt`.

    Args:
        x (np.ndarray): Audio data [T,] or [T, nb_channels]
        sampling_rate_hz (float): Sampling rate in Hz
        f_low (Optional[float], optional): Lower cutoff in Hz. None (the default) means 1.0.
        f_high (Optional[float], optional): Upper cutoff in Hz. None (the default) means
            sampling_rate_hz/2 - 1; larger values are capped at that limit.

    Returns:
        np.ndarray: Filtered data
    """
    # The upper cutoff may not exceed one Hz below the Nyquist frequency.
    upper_limit = sampling_rate_hz / 2 - 1

    low_cut = 1.0 if f_low is None else f_low
    high_cut = upper_limit if f_high is None else min(f_high, upper_limit)

    band = [low_cut, high_cut]
    sos_bp = scipy.signal.butter(5, band, "bandpass", output="sos", fs=sampling_rate_hz)
    return scipy.signal.sosfiltfilt(sos_bp, x, axis=0)