"""Defines the network architectures."""
import tensorflow as tf
import tensorflow.keras as keras
import tensorflow.keras.layers as kl
from tensorflow.keras import regularizers
from tensorflow.keras.applications.resnet_v2 import ResNet50V2
from typing import List, Optional, Tuple
from . import tcn as tcn_layer
from .kapre.time_frequency import Spectrogram
from . import spec_utils
from das.morpholayers.layers import Closing2D, Opening2D
from das.morpholayers.regularizers import l1lattice
model_dict = dict()
def _register_as_model(func):
"""Adds func to model_dict Dict[modelname: modelfunc]. For selecting models by string."""
model_dict[func.__name__] = func
return func
@_register_as_model
def tcn(*args, **kwargs):
    """Alias for ``tcn_stft``.

    All positional and keyword arguments are forwarded unchanged to
    ``tcn_stft`` and its return value is passed back to the caller.
    """
    model = tcn_stft(*args, **kwargs)
    return model
@_register_as_model
def tcn_stft(
    nb_freq: int,
    nb_classes: int,
    nb_hist: int = 1,
    nb_filters: int = 16,
    kernel_size: int = 3,
    nb_conv: int = 1,
    loss: str = "categorical_crossentropy",
    dilations: Optional[List[int]] = None,
    activation: str = "norm_relu",
    use_skip_connections: bool = True,
    return_sequences: bool = True,
    dropout_rate: float = 0.00,
    padding: str = "same",
    sample_weight_mode: Optional[str] = None,
    nb_pre_conv: int = 0,
    pre_nb_dft: int = 64,
    nb_lstm_units: int = 0,
    morph_nb_kernels: int = 1,
    morph_kernel_duration: int = 32,
    learning_rate: float = 0.0005,
    upsample: bool = True,
    use_separable: bool = False,
    use_resnet: bool = False,
    compile: bool = True,
    **kwignored,
):
    """Create TCN network with optional trainable STFT layer as pre-processing and downsampling frontend.

    Layer order: input of shape ``(nb_hist, nb_freq)`` -> optional STFT
    (-> optional ResNet50V2) -> TCN -> optional bidirectional LSTM ->
    softmax Dense -> optional morphological closing/opening -> optional
    upsampling.

    Args:
        nb_freq (int): Size of the last axis of the input.
        nb_classes (int): Number of units of the softmax output layer.
        nb_hist (int, optional): Size of the first (non-batch) axis of the input. Defaults to 1.
        nb_filters (int, optional): Forwarded to the TCN layer. Defaults to 16.
        kernel_size (int, optional): Forwarded to the TCN layer. Defaults to 3.
        nb_conv (int, optional): Forwarded to the TCN layer as ``nb_stacks``. Defaults to 1.
        loss (str, optional): Loss used when compiling. Defaults to "categorical_crossentropy".
        dilations (List[int], optional): Forwarded to the TCN layer. Defaults to [1, 2, 4, 8, 16].
        activation (str, optional): Forwarded to the TCN layer. Defaults to 'norm_relu'.
        use_skip_connections (bool, optional): Forwarded to the TCN layer. Defaults to True.
        return_sequences (bool, optional): Forwarded to the TCN layer. Defaults to True.
        dropout_rate (float, optional): Forwarded to the TCN layer. Defaults to 0.00.
        padding (str, optional): Forwarded to the TCN layer. Defaults to 'same'.
        sample_weight_mode (str, optional): Forwarded to ``model.compile``. Defaults to None.
        nb_pre_conv (int, optional): If >0 adds a single STFT layer with a hop size of 2**nb_pre_conv before the TCN.
                                     Useful for speeding up training by reducing the sample rate early in the network.
                                     Defaults to 0 (no downsampling)
        pre_nb_dft (int, optional): Duration of filters (in samples) for the STFT frontend.
                                    Number of filters is pre_nb_dft // 2 + 1.
                                    Defaults to 64.
        nb_lstm_units (int, optional): If >0 adds a bidirectional LSTM with that many units after the TCN.
                                       Defaults to 0 (no LSTM).
        morph_nb_kernels (int, optional): If >0 the softmax output is passed through a Closing2D and
                                          an Opening2D layer with that many filters.
                                          Defaults to 1. Set to 0 to disable the morphological layers.
        morph_kernel_duration (int, optional): First dimension of the morphological kernels,
                                               which have size (morph_kernel_duration, 1). Defaults to 32.
        learning_rate (float, optional): Learning rate of the Adam optimizer. Defaults to 0.0005
        upsample (bool, optional): whether or not to restore the model output to the input samplerate.
                                   Should generally be True during training and evaluation but may speed up inference.
                                   Only has an effect if nb_pre_conv > 0.
                                   Defaults to True.
        use_separable (bool, optional): use separable convs in residual block. Defaults to False.
        use_resnet (bool, optional): Pass the STFT output through an ImageNet-pretrained ResNet50V2
                                     before the TCN. The STFT and ResNet layers are then frozen.
                                     Requires nb_pre_conv > 0. Defaults to False.
        compile (bool, optional): Compile the model with Adam (clipnorm=1.0) and `loss`. Defaults to True.
        kwignored (Dict, optional): additional kw args in the param dict used for calling m(**params) to be ignored

    Raises:
        ValueError: If use_resnet is True but nb_pre_conv is not > 0.

    Returns:
        [keras.models.Model]: TCN network model (compiled only if `compile` is True).
    """
    if dilations is None:
        dilations = [1, 2, 4, 8, 16]

    # The ResNet branch consumes the 4D output of the STFT layer. Without the
    # STFT frontend the model cannot be built, so fail early and explicitly
    # instead of with an obscure shape or missing-layer error further down.
    if use_resnet and nb_pre_conv <= 0:
        raise ValueError("use_resnet=True requires the STFT frontend, i.e. nb_pre_conv > 0.")

    input_layer = kl.Input(shape=(nb_hist, nb_freq))
    out = input_layer

    stft_layer = None
    resnet = None

    if nb_pre_conv > 0:
        stft_layer = Spectrogram(
            n_dft=pre_nb_dft,
            n_hop=2**nb_pre_conv,
            return_decibel_spectrogram=True,
            power_spectrogram=1.0,
            trainable_kernel=True,
            name="trainable_stft",
            image_data_format="channels_last",
        )
        out = stft_layer(out)
        if not use_resnet:
            # Merge the last two axes of the 4D STFT output so the TCN gets a 3D tensor.
            out = kl.Reshape((out.shape[1], out.shape[2] * out.shape[3]))(out)

    if use_resnet:
        out = kl.Activation("relu")(out)
        # Replicate the tensor three times along the last axis before the
        # ImageNet-pretrained network.
        out = kl.Concatenate(axis=-1)([out, out, out])
        resnet = ResNet50V2(input_shape=out.shape[1:], weights="imagenet", include_top=False)
        out = resnet(out)
        out = kl.Reshape((out.shape[1], out.shape[2] * out.shape[3]))(out)
        # NOTE(review): batch norm is assumed to belong to the ResNet branch only - confirm.
        out = kl.BatchNormalization()(out)

    x = tcn_layer.TCN(
        nb_filters=nb_filters,
        kernel_size=kernel_size,
        nb_stacks=nb_conv,
        dilations=dilations,
        activation=activation,
        use_skip_connections=use_skip_connections,
        padding=padding,
        dropout_rate=dropout_rate,
        return_sequences=return_sequences,
        use_separable=use_separable,
    )(out)

    if nb_lstm_units > 0:
        x = kl.Bidirectional(kl.LSTM(units=nb_lstm_units, return_sequences=True))(x)

    x = kl.Dense(nb_classes, activation="softmax")(x)

    if morph_nb_kernels > 0:
        # The morphological layers are 2D: insert a singleton axis before
        # them and remove it again afterwards.
        x = x[:, :, tf.newaxis, :]
        x = Closing2D(
            num_filters=morph_nb_kernels,
            padding="same",
            kernel_size=(morph_kernel_duration, 1),
            kernel_regularization=l1lattice(0.002),
        )(x)
        x = Opening2D(
            num_filters=morph_nb_kernels,
            padding="same",
            kernel_size=(morph_kernel_duration, 1),
            kernel_regularization=l1lattice(0.002),
        )(x)
        x = x[..., 0, :]

    if nb_pre_conv > 0 and upsample:
        # Undo the downsampling by the STFT hop.
        x = kl.UpSampling1D(size=2**nb_pre_conv)(x)

    output_layer = x
    model = keras.models.Model(input_layer, output_layer, name="TCN")

    if use_resnet:
        # Freeze the frontend via the layer objects themselves rather than
        # looking them up by hard-coded name.
        stft_layer.trainable = False
        resnet.trainable = False

    if compile:
        optimizer = keras.optimizers.Adam(learning_rate=learning_rate, clipnorm=1.0)
        model.compile(optimizer=optimizer, loss=loss, sample_weight_mode=sample_weight_mode)
    return model
@_register_as_model
def stft_res_dense(
    nb_freq: int,
    nb_classes: int,
    nb_hist: int = 1,
    sample_weight_mode: Optional[str] = None,
    learning_rate: float = 0.0005,
    compile: bool = True,
    stft_compute: bool = False,
    resnet_compute: bool = False,
    resnet_train: bool = False,
    label_smoothing: float = 0,
    **kwignored,
):
    """Create a dense classifier network with an optional ResNet50V2 frontend.

    Layer order: input of shape ``(nb_hist, nb_freq)`` -> optional
    ImageNet-pretrained ResNet50V2 -> time-distributed Dense -> Flatten ->
    two tanh Dense layers -> softmax Dense with ``nb_classes`` units.

    Args:
        nb_freq (int): Size of the last axis of the input.
        nb_classes (int): Number of units of the softmax output layer.
        nb_hist (int, optional): Size of the first (non-batch) axis of the input. Defaults to 1.
        sample_weight_mode (str, optional): Forwarded to ``model.compile``. Defaults to None.
        learning_rate (float, optional): Learning rate of the Adam optimizer. Defaults to 0.0005
        compile (bool, optional): Compile the model with Adam (clipnorm=1.0) and a
                                  categorical cross-entropy loss. Defaults to True.
        stft_compute (bool, optional): Currently unused (computing the STFT in the model is still a TODO).
                                       Defaults to False.
        resnet_compute (bool, optional): Stack the input three times along a new last axis and
                                         pass it through ResNet50V2. Defaults to False.
        resnet_train (bool, optional): Fine tune resnet weights. Defaults to False.
        label_smoothing (float, optional): Forwarded to ``keras.losses.CategoricalCrossentropy``. Defaults to 0.
        kwignored (Dict, optional): additional kw args in the param dict used for calling m(**params) to be ignored

    Returns:
        [keras.models.Model]: Network model (compiled only if `compile` is True).
    """
    input_layer = kl.Input(shape=(nb_hist, nb_freq))
    out = input_layer
    # TODO compute STFT

    vision_model = None
    if resnet_compute:
        out = tf.stack((out, out, out), axis=-1)
        vision_model = ResNet50V2(input_shape=out.shape[1:], weights="imagenet", include_top=False)
        # training=False keeps the ResNet in inference mode when it is called.
        out = vision_model(out, training=False)
        # NOTE(review): batch norm is assumed to belong to the ResNet branch only - confirm.
        out = kl.BatchNormalization()(out)

    hidden_units = min(32, 4 * nb_classes)

    out = kl.TimeDistributed(
        kl.Dense(hidden_units, activation="tanh", kernel_regularizer=regularizers.L1(1e-4))
    )(out)
    # The tensor always has a batch axis plus at least one more axis here, so
    # the former `len(out.shape) > 1` guard was always true.
    out = kl.Flatten()(out)
    out = kl.Dense(hidden_units, activation="tanh", kernel_regularizer=regularizers.L1(1e-4))(out)
    out = kl.Dense(2 * nb_classes, activation="tanh", kernel_regularizer=regularizers.L1(1e-4))(out)
    out = kl.Dense(nb_classes, activation="softmax")(out)

    output_layer = out
    model = keras.models.Model(input_layer, output_layer, name="RES")

    if vision_model is not None and not resnet_train:
        # Freeze the pretrained weights unless fine tuning was requested.
        vision_model.trainable = False

    if compile:
        optimizer = keras.optimizers.Adam(learning_rate=learning_rate, clipnorm=1.0)
        model.compile(
            optimizer=optimizer,
            loss=keras.losses.CategoricalCrossentropy(label_smoothing=label_smoothing),
            sample_weight_mode=sample_weight_mode,
        )
    return model
@_register_as_model
def tcn_stft_morph(
    nb_freq: int,
    nb_classes: int,
    nb_hist: int = 1,
    nb_filters: int = 16,
    kernel_size: int = 3,
    nb_conv: int = 1,
    loss: str = "categorical_crossentropy",
    dilations: Optional[List[int]] = None,
    activation: str = "norm_relu",
    use_skip_connections: bool = True,
    return_sequences: bool = True,
    dropout_rate: float = 0.00,
    padding: str = "same",
    sample_weight_mode: Optional[str] = None,
    nb_pre_conv: int = 0,
    pre_nb_dft: int = 64,
    nb_lstm_units: int = 0,
    morph_nb_kernels: int = 1,
    morph_kernel_duration: int = 33,
    learning_rate: float = 0.0005,
    upsample: bool = True,
    use_separable: bool = False,
    use_resnet: bool = False,
    compile: bool = True,
    **kwignored,
):
    """Create TCN network with optional trainable STFT layer as pre-processing and downsampling frontend.

    Builds the same network as ``tcn_stft``; the only difference is the
    default of ``morph_kernel_duration`` (33 here). All arguments are
    forwarded to ``tcn_stft``.

    Args:
        nb_freq (int): Size of the last axis of the input.
        nb_classes (int): Number of units of the softmax output layer.
        nb_hist (int, optional): Size of the first (non-batch) axis of the input. Defaults to 1.
        nb_filters (int, optional): Forwarded to the TCN layer. Defaults to 16.
        kernel_size (int, optional): Forwarded to the TCN layer. Defaults to 3.
        nb_conv (int, optional): Forwarded to the TCN layer as ``nb_stacks``. Defaults to 1.
        loss (str, optional): Loss used when compiling. Defaults to "categorical_crossentropy".
        dilations (List[int], optional): Forwarded to the TCN layer. Defaults to [1, 2, 4, 8, 16].
        activation (str, optional): Forwarded to the TCN layer. Defaults to 'norm_relu'.
        use_skip_connections (bool, optional): Forwarded to the TCN layer. Defaults to True.
        return_sequences (bool, optional): Forwarded to the TCN layer. Defaults to True.
        dropout_rate (float, optional): Forwarded to the TCN layer. Defaults to 0.00.
        padding (str, optional): Forwarded to the TCN layer. Defaults to 'same'.
        sample_weight_mode (str, optional): Forwarded to ``model.compile``. Defaults to None.
        nb_pre_conv (int, optional): If >0 adds a single STFT layer with a hop size of 2**nb_pre_conv before the TCN.
                                     Useful for speeding up training by reducing the sample rate early in the network.
                                     Defaults to 0 (no downsampling)
        pre_nb_dft (int, optional): Duration of filters (in samples) for the STFT frontend.
                                    Number of filters is pre_nb_dft // 2 + 1.
                                    Defaults to 64.
        nb_lstm_units (int, optional): Defaults to 0 (no lstm units).
        morph_nb_kernels (int, optional): Number of filters of the morphological layers.
                                          Defaults to 1. Set to 0 to disable the morphological layers.
        morph_kernel_duration (int, optional): First dimension of the morphological kernels,
                                               which have size (morph_kernel_duration, 1). Defaults to 33.
        learning_rate (float, optional): Learning rate of the Adam optimizer. Defaults to 0.0005
        upsample (bool, optional): whether or not to restore the model output to the input samplerate.
                                   Should generally be True during training and evaluation but may speed up inference.
                                   Defaults to True.
        use_separable (bool, optional): use separable convs in residual block. Defaults to False.
        use_resnet (bool, optional): Defaults to False.
        compile (bool, optional): Compile the model with Adam (clipnorm=1.0) and `loss`. Defaults to True.
        kwignored (Dict, optional): additional kw args in the param dict used for calling m(**params) to be ignored

    Returns:
        [keras.models.Model]: TCN network model (compiled only if `compile` is True).
    """
    # Delegate to the shared builder instead of duplicating its body.
    return tcn_stft(
        nb_freq=nb_freq,
        nb_classes=nb_classes,
        nb_hist=nb_hist,
        nb_filters=nb_filters,
        kernel_size=kernel_size,
        nb_conv=nb_conv,
        loss=loss,
        dilations=dilations,
        activation=activation,
        use_skip_connections=use_skip_connections,
        return_sequences=return_sequences,
        dropout_rate=dropout_rate,
        padding=padding,
        sample_weight_mode=sample_weight_mode,
        nb_pre_conv=nb_pre_conv,
        pre_nb_dft=pre_nb_dft,
        nb_lstm_units=nb_lstm_units,
        morph_nb_kernels=morph_nb_kernels,
        morph_kernel_duration=morph_kernel_duration,
        learning_rate=learning_rate,
        upsample=upsample,
        use_separable=use_separable,
        use_resnet=use_resnet,
        compile=compile,
        **kwignored,
    )