# Train
The network can be trained using three interfaces:

- Python, via `das.train.train`
- the command-line interface `das train`
- the GUI (see the GUI tutorial)
Training will:

- load train/val/test data from a dataset
- initialize the network
- save all parameters for reproducibility
- train the network and save the best network to disk
- run inference and evaluate the network using the test data
The names of files created during training start with an optional prefix and the time stamp of the start of training, as in `my-awesome-prefix_20192310_091032`. Typically, three files are created:

- `*_params.yaml` - training parameters etc.
- `*_model.h5` - model architecture and weights
- `*_results.h5` - predictions and evaluation results for the test set (only created if the training dataset contains a test set)
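For instance, the parameters of a finished run can be inspected by reading the `*_params.yaml` file with any YAML parser. A minimal sketch, assuming a run with the prefix and timestamp from the example above, and assuming the saved fields mirror the parameter dictionary logged during training (shown later in this tutorial):

```python
import yaml

# Hypothetical path, following the SAVE_PREFIX + "_" + TIMESTAMP convention.
with open('my-awesome-prefix_20192310_091032_params.yaml', 'r') as f:
    params = yaml.safe_load(f)

# Inspect a few of the saved training parameters.
print(params['model_name'], params['nb_hist'], params['class_names'])
```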
## Training using Python
Training is done using the `train` function in the `das.train` module:
```python
import das.train
help(das.train.train)
```
```
Help on function train in module das.train:
train(*, data_dir: str, y_suffix: str = '', save_dir: str = './', save_prefix: Union[str, NoneType] = None, model_name: str = 'tcn', nb_filters: int = 16, kernel_size: int = 16, nb_conv: int = 3, use_separable: List[bool] = False, nb_hist: int = 1024, ignore_boundaries: bool = True, batch_norm: bool = True, nb_pre_conv: int = 0, pre_nb_dft: int = 64, pre_kernel_size: int = 3, pre_nb_filters: int = 16, pre_nb_conv: int = 2, nb_lstm_units: int = 0, verbose: int = 2, batch_size: int = 32, nb_epoch: int = 400, learning_rate: Union[float, NoneType] = None, reduce_lr: bool = False, reduce_lr_patience: int = 5, fraction_data: Union[float, NoneType] = None, seed: Union[int, NoneType] = None, batch_level_subsampling: bool = False, tensorboard: bool = False, neptune_api_token: Union[str, NoneType] = None, neptune_project: Union[str, NoneType] = None, log_messages: bool = False, nb_stacks: int = 2, with_y_hist: bool = True, x_suffix: str = '', balance: bool = False, version_data: bool = True, _qt_progress: bool = False) -> Tuple[keras.engine.training.Model, Dict[str, Any]]
Train a DeepSS network.
Args:
data_dir (str): Path to the directory or file with the dataset for training.
Accepts npy-dirs (recommended), h5 files or zarr files.
See documentation for how the dataset should be organized.
y_suffix (str): Select training target by suffix.
Song-type specific targets can be created with a training dataset,
Defaults to '' (will use the standard target 'y')
save_dir (str): Directory to save training outputs.
The path of output files will constructed from the SAVE_DIR, an optional prefix, and the time stamp of the start of training.
Defaults to current directory ('./').
save_prefix (Optional[str]): Prepend to timestamp.
Name of files created will be SAVE_DIR/SAVE_PREFIX + "_" + TIMESTAMP
or SAVE_DIR/ TIMESTAMP if SAVE_PREFIX is empty.
Defaults to '' (empty).
model_name (str): Network architecture to use.
Use "tcn" (TCN) or "tcn_stft" (TCN with STFT frontend).
See das.models for a description of all models.
Defaults to 'tcn'.
nb_filters (int): Number of filters per layer.
Defaults to 16.
kernel_size (int): Duration of the filters (=kernels) in samples.
Defaults to 16.
nb_conv (int): Number of TCN blocks in the network.
Defaults to 3.
use_separable (List[bool]): Specify which TCN blocks should use separable convolutions.
Provide as a space-separated sequence of "False" or "True.
For instance: "True False False" will set the first block in a
three-block (as given by nb_conv) network to use separable convolutions.
Defaults to False (no block uses separable convolution).
nb_hist (int): Number of samples processed at once by the network (a.k.a chunk size).
Defaults to 1024.
ignore_boundaries (bool): Minimize edge effects by discarding predictions at the edges of chunks.
Defaults to True.
batch_norm (bool): Batch normalize.
Defaults to True.
nb_pre_conv (int): Downsampling rate. Adds downsampling frontend if not 0.
TCN_TCN: adds a frontend of N conv blocks (conv-relu-batchnorm-maxpool2) to the TCN.
TCN_STFT: adds a trainable STFT frontend.
Defaults to 0 (no frontend).
pre_nb_dft (int): Number of filters (roughly corresponding to filters) in the STFT frontend.
Defaults to 64.
pre_nb_filters (int): Number of filters per layer in the pre-processing TCN.
Defaults to 16.
pre_kernel_size (int): Duration of filters (=kernels) in samples in the pre-processing TCN.
Defaults to 3.
nb_lstm_units (int): If >0, adds LSTM with given number of units to the output of the stack of TCN blocks.
Defaults to 0 (no LSTM layer).
verbose (int): Verbosity of training output (0 - no output(?), 1 - progress bar, 2 - one line per epoch).
Defaults to 2.
batch_size (int): Batch size
Defaults to 32.
nb_epoch (int): Maximal number of training epochs.
Training will stop early if validation loss did not decrease in the last 20 epochs.
Defaults to 400.
learning_rate (Optional[float]): Learning rate of the model. Defaults should work in most cases.
Values typically range between 0.1 and 0.00001.
If None, uses per model defaults: "tcn" 0.0001, "tcn_stft" 0.0005).
Defaults to None.
reduce_lr (bool): Reduce learning rate on plateau.
Defaults to False.
reduce_lr_patience (int): Number of epochs w/o a reduction in validation loss after which to trigger a reduction in learning rate.
Defaults to 5.
fraction_data (Optional[float]): Fraction of training and validation to use for training.
Defaults to 1.0.
seed (Optional[int]): Random seed to reproducible select fractions of the data.
Defaults to None (no seed).
batch_level_subsampling (bool): Select fraction of data for training from random subset of shuffled batches.
If False, select a continuous chunk of the recording.
Defaults to False.
tensorboard (bool): Write tensorboard logs to save_dir. Defaults to False.
neptune_api_token (Optional[str]): API token for logging to neptune.ai. Defaults to None (no logging).
neptune_project (Optional[str]): Project to log to for neptune.ai. Defaults to None (no logging).
log_messages (bool): Sets logging level to INFO.
Defaults to False (will follow existing settings).
nb_stacks (int): Unused if model name is "tcn" or "tcn_stft". Defaults to 2.
with_y_hist (bool): Unused if model name is "tcn" or "tcn_stft". Defaults to True.
x_suffix (str): Select specific training data based on suffix (e.g. x_suffix).
Defaults to '' (will use the standard data 'x')
balance (bool): Balance data. Weights class-wise errors by the inverse of the class frequencies.
Defaults to False.
version_data (bool): Save MD5 hash of the data_dir to log and params.yaml.
Defaults to True (set to False for large datasets since it can be slow).
Returns:
    model (keras.Model)
    params (Dict[str, Any])
```
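For example, to add the trainable STFT frontend described above, select the `tcn_stft` model and set `nb_pre_conv`. A hedged sketch using only arguments from the docstring; the dataset path is a placeholder:

```python
import das.train

# Sketch: TCN with a trainable STFT frontend.
model, params = das.train.train(
    model_name='tcn_stft',            # TCN with STFT frontend
    data_dir='tutorial_dataset.npy',  # placeholder dataset path
    nb_pre_conv=4,                    # non-zero value adds the downsampling frontend
    pre_nb_dft=64,                    # number of filters in the STFT frontend
)
```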
Calling the `train` function produces fairly verbose logging messages to help with troubleshooting:

- run-time parameters
- information on the size of the training and validation data
- the network architecture
- training progress (training and validation loss)
- after training, a classification report for the test data (if test data exist in the dataset)
When done, `train` returns the trained Keras model and a parameter dictionary with all arguments required to reproduce the model.
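The returned model is a plain Keras model and can be used for prediction directly. A minimal sketch, assuming audio at the training samplerate and ignoring edge effects between chunks (the recording below is synthetic, for illustration only):

```python
import numpy as np

# Synthetic stand-in for a recording; use real audio at the training samplerate.
audio = np.random.randn(100_000).astype(np.float32)

# Cut the recording into non-overlapping chunks of nb_hist samples.
nb_hist = params['nb_hist']
nb_chunks = len(audio) // nb_hist
chunks = audio[:nb_chunks * nb_hist].reshape(nb_chunks, nb_hist, 1)

# Per-sample class probabilities, one column per entry in params['class_names'].
probabilities = model.predict(chunks)
labels = probabilities.argmax(axis=-1)
```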
To demonstrate the outputs of `train`, the following trains a small network on a small dataset to annotate pulse and sine song from *Drosophila melanogaster*. Expected performance (f1-score) is about 75%.
```python
model, params = das.train.train(model_name='tcn',  # see `das.models` for valid model_names
                                data_dir='tutorial_dataset.npy',
                                save_dir='res',
                                nb_hist=256,
                                kernel_size=16,
                                nb_filters=16,
                                ignore_boundaries=True,
                                verbose=1,
                                nb_epoch=4,
                                log_messages=True)
```
```
INFO:root:Loading data from tutorial_dataset.npy.
INFO:root:Version of the data:
INFO:root: MD5 hash of tutorial_dataset.npy is
INFO:root: 34876fb30412a444e444a8e1f5312126
INFO:root:Parameters:
INFO:root:{'data_dir': 'tutorial_dataset.npy', 'y_suffix': '', 'save_dir': 'res', 'save_prefix': '', 'model_name': 'tcn', 'nb_filters': 16, 'kernel_size': 16, 'nb_conv': 3, 'use_separable': False, 'nb_hist': 256, 'ignore_boundaries': True, 'batch_norm': True, 'nb_pre_conv': 0, 'pre_nb_dft': 64, 'pre_kernel_size': 3, 'pre_nb_filters': 16, 'pre_nb_conv': 2, 'nb_lstm_units': 0, 'verbose': 1, 'batch_size': 32, 'nb_epoch': 4, 'reduce_lr': False, 'reduce_lr_patience': 5, 'fraction_data': None, 'seed': None, 'batch_level_subsampling': False, 'tensorboard': False, 'neptune_api_token': None, 'neptune_project': None, 'log_messages': True, 'nb_stacks': 2, 'with_y_hist': True, 'x_suffix': '', 'balance': False, 'version_data': True, 'sample_weight_mode': 'temporal', 'data_padding': 48, 'return_sequences': True, 'stride': 160, 'y_offset': 0, 'output_stride': 1, 'class_names': ['noise', 'pulse', 'sine'], 'class_names_pulse': ['noise', 'pulse'], 'class_names_sine': ['noise', 'sine'], 'class_types': ['segment', 'event', 'segment'], 'class_types_pulse': ['segment', 'event'], 'class_types_sine': ['segment', 'segment'], 'filename_endsample_test': [], 'filename_endsample_train': [], 'filename_endsample_val': [], 'filename_startsample_test': [], 'filename_startsample_train': [], 'filename_startsample_val': [], 'filename_train': [], 'filename_val': [], 'samplerate_x_Hz': 10000, 'samplerate_y_Hz': 10000, 'filename_test': [], 'data_hash': '34876fb30412a444e444a8e1f5312126', 'nb_freq': 1, 'nb_channels': 1, 'nb_classes': 3, 'first_sample_train': 0, 'last_sample_train': None, 'first_sample_val': 0, 'last_sample_val': None}
INFO:root:Preparing data
INFO:root:Training data:
INFO:root: AudioSequence with 3992 batches each with 32 items.
Total of 20440005 samples with
each x=(1,) and
each y=(3,)
INFO:root:Validation data:
INFO:root: AudioSequence with 812 batches each with 32 items.
Total of 4160001 samples with
each x=(1,) and
each y=(3,)
INFO:root:building network
/Users/janc/miniconda3/lib/python3.8/site-packages/keras/optimizer_v2/optimizer_v2.py:355: UserWarning: The `lr` argument is deprecated, use `learning_rate` instead.
warnings.warn(
INFO:root:None
INFO:root:Will save to res/20210924_220702.
INFO:root:start training
Model: "TCN"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_1 (InputLayer) [(None, 256, 1)] 0
__________________________________________________________________________________________________
conv1d (Conv1D) (None, 256, 16) 32 input_1[0][0]
__________________________________________________________________________________________________
conv1d_1 (Conv1D) (None, 256, 16) 4112 conv1d[0][0]
__________________________________________________________________________________________________
activation (Activation) (None, 256, 16) 0 conv1d_1[0][0]
__________________________________________________________________________________________________
lambda (Lambda) (None, 256, 16) 0 activation[0][0]
__________________________________________________________________________________________________
spatial_dropout1d (SpatialDropo (None, 256, 16) 0 lambda[0][0]
__________________________________________________________________________________________________
conv1d_2 (Conv1D) (None, 256, 16) 272 spatial_dropout1d[0][0]
__________________________________________________________________________________________________
add (Add) (None, 256, 16) 0 conv1d[0][0]
conv1d_2[0][0]
__________________________________________________________________________________________________
conv1d_3 (Conv1D) (None, 256, 16) 4112 add[0][0]
__________________________________________________________________________________________________
activation_1 (Activation) (None, 256, 16) 0 conv1d_3[0][0]
__________________________________________________________________________________________________
lambda_1 (Lambda) (None, 256, 16) 0 activation_1[0][0]
__________________________________________________________________________________________________
spatial_dropout1d_1 (SpatialDro (None, 256, 16) 0 lambda_1[0][0]
__________________________________________________________________________________________________
conv1d_4 (Conv1D) (None, 256, 16) 272 spatial_dropout1d_1[0][0]
__________________________________________________________________________________________________
add_1 (Add) (None, 256, 16) 0 add[0][0]
conv1d_4[0][0]
__________________________________________________________________________________________________
conv1d_5 (Conv1D) (None, 256, 16) 4112 add_1[0][0]
__________________________________________________________________________________________________
activation_2 (Activation) (None, 256, 16) 0 conv1d_5[0][0]
__________________________________________________________________________________________________
lambda_2 (Lambda) (None, 256, 16) 0 activation_2[0][0]
__________________________________________________________________________________________________
spatial_dropout1d_2 (SpatialDro (None, 256, 16) 0 lambda_2[0][0]
__________________________________________________________________________________________________
conv1d_6 (Conv1D) (None, 256, 16) 272 spatial_dropout1d_2[0][0]
__________________________________________________________________________________________________
add_2 (Add) (None, 256, 16) 0 add_1[0][0]
conv1d_6[0][0]
__________________________________________________________________________________________________
conv1d_7 (Conv1D) (None, 256, 16) 4112 add_2[0][0]
__________________________________________________________________________________________________
activation_3 (Activation) (None, 256, 16) 0 conv1d_7[0][0]
__________________________________________________________________________________________________
lambda_3 (Lambda) (None, 256, 16) 0 activation_3[0][0]
__________________________________________________________________________________________________
spatial_dropout1d_3 (SpatialDro (None, 256, 16) 0 lambda_3[0][0]
__________________________________________________________________________________________________
conv1d_8 (Conv1D) (None, 256, 16) 272 spatial_dropout1d_3[0][0]
__________________________________________________________________________________________________
add_3 (Add) (None, 256, 16) 0 add_2[0][0]
conv1d_8[0][0]
__________________________________________________________________________________________________
conv1d_9 (Conv1D) (None, 256, 16) 4112 add_3[0][0]
__________________________________________________________________________________________________
activation_4 (Activation) (None, 256, 16) 0 conv1d_9[0][0]
__________________________________________________________________________________________________
lambda_4 (Lambda) (None, 256, 16) 0 activation_4[0][0]
__________________________________________________________________________________________________
spatial_dropout1d_4 (SpatialDro (None, 256, 16) 0 lambda_4[0][0]
__________________________________________________________________________________________________
conv1d_10 (Conv1D) (None, 256, 16) 272 spatial_dropout1d_4[0][0]
__________________________________________________________________________________________________
add_4 (Add) (None, 256, 16) 0 add_3[0][0]
conv1d_10[0][0]
__________________________________________________________________________________________________
conv1d_11 (Conv1D) (None, 256, 16) 4112 add_4[0][0]
__________________________________________________________________________________________________
activation_5 (Activation) (None, 256, 16) 0 conv1d_11[0][0]
__________________________________________________________________________________________________
lambda_5 (Lambda) (None, 256, 16) 0 activation_5[0][0]
__________________________________________________________________________________________________
spatial_dropout1d_5 (SpatialDro (None, 256, 16) 0 lambda_5[0][0]
__________________________________________________________________________________________________
conv1d_12 (Conv1D) (None, 256, 16) 272 spatial_dropout1d_5[0][0]
__________________________________________________________________________________________________
add_5 (Add) (None, 256, 16) 0 add_4[0][0]
conv1d_12[0][0]
__________________________________________________________________________________________________
conv1d_13 (Conv1D) (None, 256, 16) 4112 add_5[0][0]
__________________________________________________________________________________________________
activation_6 (Activation) (None, 256, 16) 0 conv1d_13[0][0]
__________________________________________________________________________________________________
lambda_6 (Lambda) (None, 256, 16) 0 activation_6[0][0]
__________________________________________________________________________________________________
spatial_dropout1d_6 (SpatialDro (None, 256, 16) 0 lambda_6[0][0]
__________________________________________________________________________________________________
conv1d_14 (Conv1D) (None, 256, 16) 272 spatial_dropout1d_6[0][0]
__________________________________________________________________________________________________
add_6 (Add) (None, 256, 16) 0 add_5[0][0]
conv1d_14[0][0]
__________________________________________________________________________________________________
conv1d_15 (Conv1D) (None, 256, 16) 4112 add_6[0][0]
__________________________________________________________________________________________________
activation_7 (Activation) (None, 256, 16) 0 conv1d_15[0][0]
__________________________________________________________________________________________________
lambda_7 (Lambda) (None, 256, 16) 0 activation_7[0][0]
__________________________________________________________________________________________________
spatial_dropout1d_7 (SpatialDro (None, 256, 16) 0 lambda_7[0][0]
__________________________________________________________________________________________________
conv1d_16 (Conv1D) (None, 256, 16) 272 spatial_dropout1d_7[0][0]
__________________________________________________________________________________________________
add_7 (Add) (None, 256, 16) 0 add_6[0][0]
conv1d_16[0][0]
__________________________________________________________________________________________________
conv1d_17 (Conv1D) (None, 256, 16) 4112 add_7[0][0]
__________________________________________________________________________________________________
activation_8 (Activation) (None, 256, 16) 0 conv1d_17[0][0]
__________________________________________________________________________________________________
lambda_8 (Lambda) (None, 256, 16) 0 activation_8[0][0]
__________________________________________________________________________________________________
spatial_dropout1d_8 (SpatialDro (None, 256, 16) 0 lambda_8[0][0]
__________________________________________________________________________________________________
conv1d_18 (Conv1D) (None, 256, 16) 272 spatial_dropout1d_8[0][0]
__________________________________________________________________________________________________
add_8 (Add) (None, 256, 16) 0 add_7[0][0]
conv1d_18[0][0]
__________________________________________________________________________________________________
conv1d_19 (Conv1D) (None, 256, 16) 4112 add_8[0][0]
__________________________________________________________________________________________________
activation_9 (Activation) (None, 256, 16) 0 conv1d_19[0][0]
__________________________________________________________________________________________________
lambda_9 (Lambda) (None, 256, 16) 0 activation_9[0][0]
__________________________________________________________________________________________________
spatial_dropout1d_9 (SpatialDro (None, 256, 16) 0 lambda_9[0][0]
__________________________________________________________________________________________________
conv1d_20 (Conv1D) (None, 256, 16) 272 spatial_dropout1d_9[0][0]
__________________________________________________________________________________________________
add_9 (Add) (None, 256, 16) 0 add_8[0][0]
conv1d_20[0][0]
__________________________________________________________________________________________________
conv1d_21 (Conv1D) (None, 256, 16) 4112 add_9[0][0]
__________________________________________________________________________________________________
activation_10 (Activation) (None, 256, 16) 0 conv1d_21[0][0]
__________________________________________________________________________________________________
lambda_10 (Lambda) (None, 256, 16) 0 activation_10[0][0]
__________________________________________________________________________________________________
spatial_dropout1d_10 (SpatialDr (None, 256, 16) 0 lambda_10[0][0]
__________________________________________________________________________________________________
conv1d_22 (Conv1D) (None, 256, 16) 272 spatial_dropout1d_10[0][0]
__________________________________________________________________________________________________
add_10 (Add) (None, 256, 16) 0 add_9[0][0]
conv1d_22[0][0]
__________________________________________________________________________________________________
conv1d_23 (Conv1D) (None, 256, 16) 4112 add_10[0][0]
__________________________________________________________________________________________________
activation_11 (Activation) (None, 256, 16) 0 conv1d_23[0][0]
__________________________________________________________________________________________________
lambda_11 (Lambda) (None, 256, 16) 0 activation_11[0][0]
__________________________________________________________________________________________________
spatial_dropout1d_11 (SpatialDr (None, 256, 16) 0 lambda_11[0][0]
__________________________________________________________________________________________________
conv1d_24 (Conv1D) (None, 256, 16) 272 spatial_dropout1d_11[0][0]
__________________________________________________________________________________________________
add_11 (Add) (None, 256, 16) 0 add_10[0][0]
conv1d_24[0][0]
__________________________________________________________________________________________________
conv1d_25 (Conv1D) (None, 256, 16) 4112 add_11[0][0]
__________________________________________________________________________________________________
activation_12 (Activation) (None, 256, 16) 0 conv1d_25[0][0]
__________________________________________________________________________________________________
lambda_12 (Lambda) (None, 256, 16) 0 activation_12[0][0]
__________________________________________________________________________________________________
spatial_dropout1d_12 (SpatialDr (None, 256, 16) 0 lambda_12[0][0]
__________________________________________________________________________________________________
conv1d_26 (Conv1D) (None, 256, 16) 272 spatial_dropout1d_12[0][0]
__________________________________________________________________________________________________
add_12 (Add) (None, 256, 16) 0 add_11[0][0]
conv1d_26[0][0]
__________________________________________________________________________________________________
conv1d_27 (Conv1D) (None, 256, 16) 4112 add_12[0][0]
__________________________________________________________________________________________________
activation_13 (Activation) (None, 256, 16) 0 conv1d_27[0][0]
__________________________________________________________________________________________________
lambda_13 (Lambda) (None, 256, 16) 0 activation_13[0][0]
__________________________________________________________________________________________________
spatial_dropout1d_13 (SpatialDr (None, 256, 16) 0 lambda_13[0][0]
__________________________________________________________________________________________________
conv1d_28 (Conv1D) (None, 256, 16) 272 spatial_dropout1d_13[0][0]
__________________________________________________________________________________________________
add_13 (Add) (None, 256, 16) 0 add_12[0][0]
conv1d_28[0][0]
__________________________________________________________________________________________________
conv1d_29 (Conv1D) (None, 256, 16) 4112 add_13[0][0]
__________________________________________________________________________________________________
activation_14 (Activation) (None, 256, 16) 0 conv1d_29[0][0]
__________________________________________________________________________________________________
lambda_14 (Lambda) (None, 256, 16) 0 activation_14[0][0]
__________________________________________________________________________________________________
spatial_dropout1d_14 (SpatialDr (None, 256, 16) 0 lambda_14[0][0]
__________________________________________________________________________________________________
conv1d_30 (Conv1D) (None, 256, 16) 272 spatial_dropout1d_14[0][0]
__________________________________________________________________________________________________
add_15 (Add) (None, 256, 16) 0 conv1d_2[0][0]
conv1d_4[0][0]
conv1d_6[0][0]
conv1d_8[0][0]
conv1d_10[0][0]
conv1d_12[0][0]
conv1d_14[0][0]
conv1d_16[0][0]
conv1d_18[0][0]
conv1d_20[0][0]
conv1d_22[0][0]
conv1d_24[0][0]
conv1d_26[0][0]
conv1d_28[0][0]
conv1d_30[0][0]
__________________________________________________________________________________________________
activation_15 (Activation) (None, 256, 16) 0 add_15[0][0]
__________________________________________________________________________________________________
dense (Dense) (None, 256, 3) 51 activation_15[0][0]
__________________________________________________________________________________________________
activation_16 (Activation) (None, 256, 3) 0 dense[0][0]
==================================================================================================
Total params: 65,843
Trainable params: 65,843
Non-trainable params: 0
__________________________________________________________________________________________________
Epoch 1/4
1000/1000 [==============================] - ETA: 0s - batch: 499.5000 - size: 32.0000 - loss: 0.1143
/Users/janc/miniconda3/lib/python3.8/site-packages/keras/engine/training.py:2470: UserWarning: `Model.state_updates` will be removed in a future version. This property should not be used in TensorFlow 2.0, as `updates` are applied automatically.
warnings.warn('`Model.state_updates` will be removed in a future version. '
Epoch 00001: val_loss improved from inf to 0.11043, saving model to res/20210924_220702_model.h5
/Users/janc/miniconda3/lib/python3.8/site-packages/keras/utils/generic_utils.py:494: CustomMaskWarning: Custom mask layers require a config and must override get_config. When loading, the custom mask layer must be passed to the custom_objects argument.
warnings.warn('Custom mask layers require a config and must override '
1000/1000 [==============================] - 241s 236ms/step - batch: 499.5000 - size: 32.0000 - loss: 0.1143 - val_loss: 0.1104
Epoch 2/4
1000/1000 [==============================] - ETA: 0s - batch: 499.5000 - size: 32.0000 - loss: 0.0841
Epoch 00002: val_loss improved from 0.11043 to 0.10770, saving model to res/20210924_220702_model.h5
1000/1000 [==============================] - 226s 226ms/step - batch: 499.5000 - size: 32.0000 - loss: 0.0841 - val_loss: 0.1077
Epoch 3/4
1000/1000 [==============================] - ETA: 0s - batch: 499.5000 - size: 32.0000 - loss: 0.0867
Epoch 00003: val_loss did not improve from 0.10770
1000/1000 [==============================] - 215s 215ms/step - batch: 499.5000 - size: 32.0000 - loss: 0.0867 - val_loss: 0.1103
Epoch 4/4
1000/1000 [==============================] - ETA: 0s - batch: 499.5000 - size: 32.0000 - loss: 0.0823
Epoch 00004: val_loss improved from 0.10770 to 0.09988, saving model to res/20210924_220702_model.h5
1000/1000 [==============================] - 218s 218ms/step - batch: 499.5000 - size: 32.0000 - loss: 0.0823 - val_loss: 0.0999
INFO:root:re-loading last best model
INFO:root:predicting
INFO:root:evaluating
INFO:root:[[3545939 7799 38820]
[ 10658 33510 140]
[ 99569 58 241747]]
INFO:root:{'noise': {'precision': 0.9698517518077681, 'recall': 0.9870234523701497, 'f1-score': 0.9783622607234045, 'support': 3592558}, 'pulse': {'precision': 0.8100659946334035, 'recall': 0.7562968312720051, 'f1-score': 0.7822585351619492, 'support': 44308}, 'sine': {'precision': 0.8612075936830931, 'recall': 0.7081587935812335, 'f1-score': 0.7772203298284307, 'support': 341374}, 'accuracy': 0.960524251930502, 'macro avg': {'precision': 0.8803751133747548, 'recall': 0.8171596924077962, 'f1-score': 0.8459470419045948, 'support': 3978240}, 'weighted avg': {'precision': 0.9587493351198522, 'recall': 0.960524251930502, 'f1-score': 0.958918087071358, 'support': 3978240}}
INFO:root:saving to res/20210924_220702_results.h5.
/Users/janc/miniconda3/lib/python3.8/site-packages/tables/attributeset.py:464: NaturalNameWarning: object name is not a valid Python identifier: 'f1-score'; it does not match the pattern ``^[a-zA-Z_][a-zA-Z0-9_]*$``; you will not be able to use natural naming to access this object; using ``getattr()`` will still work, though
check_attribute_name(name)
/Users/janc/miniconda3/lib/python3.8/site-packages/tables/path.py:155: NaturalNameWarning: object name is not a valid Python identifier: 'macro avg'; it does not match the pattern ``^[a-zA-Z_][a-zA-Z0-9_]*$``; you will not be able to use natural naming to access this object; using ``getattr()`` will still work, though
check_attribute_name(name)
/Users/janc/miniconda3/lib/python3.8/site-packages/tables/path.py:155: NaturalNameWarning: object name is not a valid Python identifier: 'weighted avg'; it does not match the pattern ``^[a-zA-Z_][a-zA-Z0-9_]*$``; you will not be able to use natural naming to access this object; using ``getattr()`` will still work, though
check_attribute_name(name)
INFO:root:DONE.
```
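The confusion matrix and classification report printed above are also stored in the `*_results.h5` file. A sketch for inspecting its contents, assuming `h5py` is installed (the file is written via pytables, as the warnings above suggest, so adapt the keys to the layout you find):

```python
import h5py

# Print the name of every group and dataset stored in the results file.
with h5py.File('res/20210924_220702_results.h5', 'r') as f:
    f.visit(print)
```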
## Training using command-line scripts
The training function `das.train.train` and all its arguments are also accessible from the command line via `das train`. See here for a description of all command-line arguments. The command-line interface is generated with defopt.
For instance, the training command above can be invoked from the command line:
```bash
das train --data-dir dat/dmel_single_raw.npy --save-dir res --model-name tcn --kernel-size 16 --nb-filters 16 --nb-hist 512 --nb-epoch 20 -i
```
Shell scripts are particularly useful if you want to fit the network with different configurations to optimize structural parameters. For instance, this script will fit networks with different numbers of TCN blocks (`nb_conv`) and filters (`nb_filters`); a sketch for collecting the results of such a sweep follows the script.
```bash
#!/bin/bash
conda activate das

YSUFFIX="pulse"
MODELNAME='tcn'
DATADIR='../dat/dmel_single.npy'
SAVEDIR="res"

NB_HIST=2048
KERNEL_SIZE=32
NB_FILTERS=32
NB_CONV=3

for NB_CONV in 2 3 4
do
    for NB_FILTERS in 16 32 64
    do
        das train -i --nb-filters $NB_FILTERS --kernel-size $KERNEL_SIZE --nb-conv $NB_CONV --nb-hist $NB_HIST --save-dir $SAVEDIR --y-suffix $YSUFFIX --data-dir $DATADIR --model-name $MODELNAME
    done
done
```
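After the sweep, the runs can be compared by collecting the saved parameter files. A minimal sketch, assuming the runs were saved to `res/` as above and that the saved fields mirror the logged parameter dictionary:

```python
from glob import glob

import yaml

# List each run in the sweep together with the two parameters that were varied.
for params_file in sorted(glob('res/*_params.yaml')):
    with open(params_file, 'r') as f:
        params = yaml.safe_load(f)
    print(params_file, params['nb_conv'], params['nb_filters'])
```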
A description of all command-line arguments can be obtained by typing `das train --help` in a terminal:
```
!das train --help
```
```
usage: das train [-h] -d DATA_DIR [-y Y_SUFFIX] [--save-dir SAVE_DIR]
[--save-prefix SAVE_PREFIX] [-m MODEL_NAME]
[--nb-filters NB_FILTERS] [-k KERNEL_SIZE]
[--nb-conv NB_CONV] [-u [USE_SEPARABLE [USE_SEPARABLE ...]]]
[--nb-hist NB_HIST]
[-i | --ignore-boundaries | --no-ignore-boundaries]
[--batch-norm | --no-batch-norm] [--nb-pre-conv NB_PRE_CONV]
[--pre-nb-dft PRE_NB_DFT] [--pre-kernel-size PRE_KERNEL_SIZE]
[--pre-nb-filters PRE_NB_FILTERS] [--pre-nb-conv PRE_NB_CONV]
[--nb-lstm-units NB_LSTM_UNITS] [--verbose VERBOSE]
[--batch-size BATCH_SIZE] [--nb-epoch NB_EPOCH]
[--learning-rate LEARNING_RATE]
[--reduce-lr | --no-reduce-lr]
[--reduce-lr-patience REDUCE_LR_PATIENCE] [-f FRACTION_DATA]
[--seed SEED]
[--batch-level-subsampling | --no-batch-level-subsampling]
[-t | --tensorboard | --no-tensorboard]
[--neptune-api-token NEPTUNE_API_TOKEN]
[--neptune-project NEPTUNE_PROJECT]
[--log-messages | --no-log-messages] [--nb-stacks NB_STACKS]
[-w | --with-y-hist | --no-with-y-hist] [-x X_SUFFIX]
[--balance | --no-balance]
[--version-data | --no-version-data]
Train a DeepSS network.
optional arguments:
-h, --help show this help message and exit
-d DATA_DIR, --data-dir DATA_DIR
Path to the directory or file with the dataset for training.
Accepts npy-dirs (recommended), h5 files or zarr files.
See documentation for how the dataset should be organized.
-y Y_SUFFIX, --y-suffix Y_SUFFIX
Select training target by suffix.
Song-type specific targets can be created with a training dataset,
Defaults to '' (will use the standard target 'y')
--save-dir SAVE_DIR Directory to save training outputs.
The path of output files will constructed from the SAVE_DIR, an optional prefix, and the time stamp of the start of training.
Defaults to current directory ('./').
--save-prefix SAVE_PREFIX
Prepend to timestamp.
Name of files created will be SAVE_DIR/SAVE_PREFIX + "_" + TIMESTAMP
or SAVE_DIR/ TIMESTAMP if SAVE_PREFIX is empty.
Defaults to '' (empty).
-m MODEL_NAME, --model-name MODEL_NAME
Network architecture to use.
Use "tcn" (TCN) or "tcn_stft" (TCN with STFT frontend).
See das.models for a description of all models.
Defaults to 'tcn'.
--nb-filters NB_FILTERS
Number of filters per layer.
Defaults to 16.
-k KERNEL_SIZE, --kernel-size KERNEL_SIZE
Duration of the filters (=kernels) in samples.
Defaults to 16.
--nb-conv NB_CONV Number of TCN blocks in the network.
Defaults to 3.
-u [USE_SEPARABLE [USE_SEPARABLE ...]], --use-separable [USE_SEPARABLE [USE_SEPARABLE ...]]
Specify which TCN blocks should use separable convolutions.
Provide as a space-separated sequence of "False" or "True.
For instance: "True False False" will set the first block in a
three-block (as given by nb_conv) network to use separable convolutions.
Defaults to False (no block uses separable convolution).
--nb-hist NB_HIST Number of samples processed at once by the network (a.k.a chunk size).
Defaults to 1024.
-i, --ignore-boundaries, --no-ignore-boundaries
Minimize edge effects by discarding predictions at the edges of chunks.
Defaults to True.
--batch-norm, --no-batch-norm
Batch normalize.
Defaults to True.
--nb-pre-conv NB_PRE_CONV
Downsampling rate. Adds downsampling frontend if not 0.
TCN_TCN: adds a frontend of N conv blocks (conv-relu-batchnorm-maxpool2) to the TCN.
TCN_STFT: adds a trainable STFT frontend.
Defaults to 0 (no frontend).
--pre-nb-dft PRE_NB_DFT
Number of filters (roughly corresponding to filters) in the STFT frontend.
Defaults to 64.
--pre-kernel-size PRE_KERNEL_SIZE
Duration of filters (=kernels) in samples in the pre-processing TCN.
Defaults to 3.
--pre-nb-filters PRE_NB_FILTERS
Number of filters per layer in the pre-processing TCN.
Defaults to 16.
--pre-nb-conv PRE_NB_CONV
--nb-lstm-units NB_LSTM_UNITS
If >0, adds LSTM with given number of units to the output of the stack of TCN blocks.
Defaults to 0 (no LSTM layer).
--verbose VERBOSE Verbosity of training output (0 - no output(?), 1 - progress bar, 2 - one line per epoch).
Defaults to 2.
--batch-size BATCH_SIZE
Batch size
Defaults to 32.
--nb-epoch NB_EPOCH Maximal number of training epochs.
Training will stop early if validation loss did not decrease in the last 20 epochs.
Defaults to 400.
--learning-rate LEARNING_RATE
Learning rate of the model. Defaults should work in most cases.
Values typically range between 0.1 and 0.00001.
If None, uses per model defaults: "tcn" 0.0001, "tcn_stft" 0.0005).
Defaults to None.
--reduce-lr, --no-reduce-lr
Reduce learning rate on plateau.
Defaults to False.
--reduce-lr-patience REDUCE_LR_PATIENCE
Number of epochs w/o a reduction in validation loss after which to trigger a reduction in learning rate.
Defaults to 5.
-f FRACTION_DATA, --fraction-data FRACTION_DATA
Fraction of training and validation to use for training.
Defaults to 1.0.
--seed SEED Random seed to reproducible select fractions of the data.
Defaults to None (no seed).
--batch-level-subsampling, --no-batch-level-subsampling
Select fraction of data for training from random subset of shuffled batches.
If False, select a continuous chunk of the recording.
Defaults to False.
-t, --tensorboard, --no-tensorboard
Write tensorboard logs to save_dir. Defaults to False.
--neptune-api-token NEPTUNE_API_TOKEN
API token for logging to neptune.ai. Defaults to None (no logging).
--neptune-project NEPTUNE_PROJECT
Project to log to for neptune.ai. Defaults to None (no logging).
--log-messages, --no-log-messages
Sets logging level to INFO.
Defaults to False (will follow existing settings).
--nb-stacks NB_STACKS
Unused if model name is "tcn" or "tcn_stft". Defaults to 2.
-w, --with-y-hist, --no-with-y-hist
Unused if model name is "tcn" or "tcn_stft". Defaults to True.
-x X_SUFFIX, --x-suffix X_SUFFIX
Select specific training data based on suffix (e.g. x_suffix).
Defaults to '' (will use the standard data 'x')
--balance, --no-balance
Balance data. Weights class-wise errors by the inverse of the class frequencies.
Defaults to False.
--version-data, --no-version-data
Save MD5 hash of the data_dir to log and params.yaml.
Defaults to True (set to False for large datasets since it can be slow).
```