Source code for das.kapre.filterbank

# -*- coding: utf-8 -*-
from __future__ import absolute_import
from tensorflow.keras.layers import Layer
from tensorflow.keras import backend as K
from . import backend


# Todo: Filterbank(); init with mel, log, linear, etc.
# not parameterised, just a matrix multiplication


class Filterbank(Layer):
    """Project the frequency axis of a 2D (image-format) input onto a filterbank.

    The layer multiplies the frequency axis of its input by a
    ``(n_freq, n_fbs)`` matrix that is created in :meth:`build` by the
    project's ``backend.filterbank_mel`` / ``backend.filterbank_log`` helpers.

    ``kapre.filterbank.Filterbank(n_fbs, trainable_fb, sr=None, init='mel',
    fmin=0., fmax=None, bins_per_octave=12, image_data_format='default',
    **kwargs)``

    Notes
    -----
    Input/output are in 2D image format. E.g., if ``channels_first``:

    - input_shape: ``(None, n_ch, n_freqs, n_time)``
    - output_shape: ``(None, n_ch, n_fbs, n_time)``

    and if ``channels_last``:

    - input_shape: ``(None, n_freqs, n_time, n_ch)``
    - output_shape: ``(None, n_fbs, n_time, n_ch)``

    Parameters
    ----------
    n_fbs : int
        Number of filterbanks.
    trainable_fb : bool
        Whether the filterbank matrix is trainable or not.
    sr : int, optional
        Sampling rate. Required when ``init`` is ``'mel'`` or ``'log'``.
    init : str
        One of ``'mel'``, ``'log'``, ``'linear'``, ``'uni_random'``.
        Only ``'mel'`` and ``'log'`` are implemented by :meth:`build`;
        the other two are accepted here but fail at build time.
    fmin : float
        Min frequency of the filterbanks. If ``init == 'log'``, fmin should
        be > 0. Use ``None`` if you have no idea.
    fmax : float, optional
        Max frequency of the filterbanks. Defaults to ``sr / 2`` when ``sr``
        is given. If ``init == 'log'``, fmax is ignored.
    bins_per_octave : int
        Passed to ``backend.filterbank_log`` when ``init == 'log'``.
    image_data_format : str
        ``'default'`` (use ``K.image_data_format()``), ``'channels_first'``
        or ``'channels_last'``.
    **kwargs
        Forwarded to ``tensorflow.keras.layers.Layer``.
    """

    def __init__(
        self,
        n_fbs,
        trainable_fb,
        sr=None,
        init="mel",
        fmin=0.0,
        fmax=None,
        bins_per_octave=12,
        image_data_format="default",
        **kwargs
    ):
        """Validate the arguments and store the layer configuration.

        Raises:
            AssertionError: if ``init`` or ``image_data_format`` is not one of
                the accepted values, or if ``sr`` is missing for a
                ``'mel'``/``'log'`` initialisation.
        """
        # Initialise the base layer first so that it cannot reset attributes
        # (e.g. ``supports_masking``) assigned below.
        super(Filterbank, self).__init__(**kwargs)

        assert init in ("mel", "log", "linear", "uni_random"), (
            "init must be one of 'mel', 'log', 'linear', 'uni_random'"
        )
        assert image_data_format in ("default", "channels_first", "channels_last"), (
            "image_data_format must be 'default', 'channels_first' or 'channels_last'"
        )
        if init in ("mel", "log"):
            assert sr is not None, "sr is required when init is 'mel' or 'log'"

        self.supports_masking = True
        self.n_fbs = n_fbs
        self.sr = sr
        self.init = init
        self.fmin = fmin
        # Default fmax to the Nyquist frequency, but only when sr is known;
        # dividing None by 2.0 would raise TypeError before any validation.
        if fmax is None:
            self.fmax = None if sr is None else sr / 2.0
        else:
            self.fmax = fmax
        self.bins_per_octave = bins_per_octave
        self.trainable_fb = trainable_fb

        if image_data_format == "default":
            self.image_data_format = K.image_data_format()
        else:
            self.image_data_format = image_data_format

    def build(self, input_shape):
        """Read the input dimensions and create the filterbank weight.

        Args:
            input_shape: 4D shape of the input, interpreted according to
                ``self.image_data_format``.

        Raises:
            NotImplementedError: if ``self.init`` is not ``'mel'`` or ``'log'``.
        """
        # Local import: the module only imports `Layer` and `backend as K`.
        from tensorflow.keras import initializers

        if self.image_data_format == "channels_first":
            self.n_ch, self.n_freq, self.n_time = input_shape[1:4]
        else:
            self.n_freq, self.n_time, self.n_ch = input_shape[1:4]

        if self.init == "mel":
            fb_matrix = backend.filterbank_mel(
                sr=self.sr,
                n_freq=self.n_freq,
                n_mels=self.n_fbs,
                fmin=self.fmin,
                fmax=self.fmax,
            ).transpose()
        elif self.init == "log":
            fb_matrix = backend.filterbank_log(
                sr=self.sr,
                n_freq=self.n_freq,
                n_bins=self.n_fbs,
                bins_per_octave=self.bins_per_octave,
                fmin=self.fmin,
            ).transpose()
        else:
            # Previously this fell through and failed later with an
            # AttributeError on `self.filterbank`.
            raise NotImplementedError(
                "Filterbank init '{}' is not implemented; use 'mel' or 'log'.".format(
                    self.init
                )
            )

        # Register through add_weight so that `trainable_fb` is honoured.
        # Appending to `self.trainable_weights` / `self.non_trainable_weights`
        # only mutates a temporary list returned by those properties.
        self.filterbank = self.add_weight(
            name="filterbank",
            shape=fb_matrix.shape,
            dtype=K.floatx(),
            initializer=initializers.Constant(fb_matrix),
            trainable=self.trainable_fb,
        )

        super(Filterbank, self).build(input_shape)
        self.built = True

    def compute_output_shape(self, input_shape):
        """Return the input shape with the frequency axis replaced by ``n_fbs``."""
        if self.image_data_format == "channels_first":
            return input_shape[0], self.n_ch, self.n_fbs, self.n_time
        return input_shape[0], self.n_fbs, self.n_time, self.n_ch

    def call(self, x):
        """Apply the filterbank along the frequency axis of ``x``."""
        # Permutation that moves the frequency axis last. In both layouts it
        # is a swap of two axes, so applying it again restores the layout.
        if self.image_data_format == "channels_first":
            perm = [0, 1, 3, 2]
        else:
            perm = [0, 3, 2, 1]

        freq_last = K.permute_dimensions(x, perm)
        projected = K.dot(freq_last, self.filterbank)
        return K.permute_dimensions(projected, perm)

    def get_config(self):
        """Return the layer configuration, including ``image_data_format``."""
        config = super(Filterbank, self).get_config()
        config.update(
            {
                "n_fbs": self.n_fbs,
                "sr": self.sr,
                "init": self.init,
                "fmin": self.fmin,
                "fmax": self.fmax,
                "bins_per_octave": self.bins_per_octave,
                "trainable_fb": self.trainable_fb,
                # Needed so from_config() rebuilds the layer with the same
                # axis layout instead of the backend default.
                "image_data_format": self.image_data_format,
            }
        )
        return config