Source code for das.spec_utils

import logging
import tensorflow.keras as keras
import tensorflow as tf

try:
    import tensorflow_probability as tfp
except ImportError:
    tfp = None
    logging.debug("No tensorflow-probability.")
    pass


[docs]class MelSpec(keras.layers.Layer): # copied from https://keras.io/examples/audio/melgan_spectrogram_inversion def __init__( self, sampling_rate, frame_length=512, frame_step=None, fft_length=None, num_mel_channels=128, freq_min=0, freq_max=None, **kwargs, ): super().__init__(**kwargs) self.sampling_rate = sampling_rate self.frame_length = frame_length if frame_step is None: frame_step = self.frame_length // 4 self.frame_step = frame_step if fft_length is None: fft_length = self.frame_length * 2 self.fft_length = fft_length self.num_mel_channels = num_mel_channels self.freq_min = freq_min if freq_max is None: freq_max = self.sampling_rate // 2 self.freq_max = freq_max # Defining mel filter. This filter will be multiplied with the STFT output self.mel_filterbank = tf.signal.linear_to_mel_weight_matrix( num_mel_bins=self.num_mel_channels, num_spectrogram_bins=self.fft_length // 2 + 1, sample_rate=self.sampling_rate, lower_edge_hertz=self.freq_min, upper_edge_hertz=self.freq_max, )
[docs] def call(self, audio, training=True): # We will only perform the transformation during training. # if training: # Taking the Short Time Fourier Transform. Ensure that the audio is padded. # In the paper, the STFT output is padded using the 'REFLECT' strategy. stft = tf.signal.stft( tf.squeeze(audio, -1), self.frame_length, self.frame_step, self.fft_length, pad_end=True, ) # Taking the magnitude of the STFT output magnitude = tf.abs(stft) # Multiplying the Mel-filterbank with the magnitude and scaling it using the db scale mel = tf.matmul(tf.square(magnitude), self.mel_filterbank) # # baseline subtract if tfp is not None: baseline = tfp.stats.percentile(mel, 50, axis=1, keepdims=True) mel = tf.math.log(mel / (baseline + 0.000001)) mel = tf.clip_by_value(mel, clip_value_min=0, clip_value_max=100_000) mel /= tf.reduce_max(mel) * 255 return mel
[docs] def get_config(self): config = super(MelSpec, self).get_config() config.update( { "frame_length": self.frame_length, "frame_step": self.frame_step, "fft_length": self.fft_length, "sampling_rate": self.sampling_rate, "num_mel_channels": self.num_mel_channels, "freq_min": self.freq_min, "freq_max": self.freq_max, } ) return config