Source code for das.pulse_utils

"""Utilities for handling pulses."""

import numpy as np
import scipy.signal as ss
from typing import List, Tuple


[docs]def normalize_pulse(pulse: np.ndarray, smooth_win: int = 15, flip_win: int = 10) -> np.ndarray: """Normalize pulses. 1. scales to unit-norm, 2. aligns to energy maximum, 3. flips so that pre-peak mean is positive Args: pulse (np.ndarray): should be [T,] smooth_win (int, optional): n samples of rect window used to smooth squared pulse for peak detection flip_win (int, optional): number of samples pre-peak used for determining sign of pulse for flipping. Returns: np.ndarray: normalized pulse [T,] """ # scale pulse /= np.linalg.norm(pulse) pulse_len = len(pulse) pulse_len_half = int(pulse_len / 2) # center gwin = ss.windows.boxcar(int(smooth_win)) pulse_env = np.convolve(pulse**2, gwin, mode="valid") offset = np.argmax(pulse_env) + int(np.ceil(smooth_win / 2)) + 1 pulse = np.pad(pulse, (len(pulse) - offset, offset), mode="constant", constant_values=0) # flip if np.sum(pulse[pulse_len - flip_win : pulse_len]) < 0: pulse *= -1 return pulse[pulse_len_half:-pulse_len_half]
[docs]def center_of_mass(x: np.ndarray, y: np.ndarray, thres: float = 0.5) -> float: """Calculate center of mass of y. Args: x (np.ndarray): y (np.ndarray): thres (float, optional): Threshold. Defaults to 0.5. Returns: float: Center of mass. """ y /= np.max(y) y -= thres y[y < 0] = 0 y /= np.sum(y) com = np.dot(x, y) return com
[docs]def pulse_freq( pulse: np.ndarray, fftlen: int = 1000, sampling_rate: int = 10000, mean_subtract: bool = True ) -> Tuple[float, np.ndarray, np.ndarray]: """Calculate pulse frequency as center of mass of the pulse's amplitude spectrum. Args: pulse (np.ndarray): Waveform (shape [T,]). fftlen (int, optional): Sets freq resolution of the spectrum. Defaults to 1_000. sampling_rate (float, optional): Sample rate of the pulse, in Hz. Defaults to 10_000. mean_subtract (bool, optional): If true, removes f0 component by mean subtraction. Defaults to True. Returns: Tuple[float, np.ndarray, np.ndarray]: Center frequency, frequency and amplitude values of the pulse spectrum (cut off at 1000 Hz). """ if mean_subtract: pulse -= np.mean(pulse) F = np.fft.rfftfreq(fftlen, 1 / sampling_rate) A = np.abs(np.fft.rfft(pulse, fftlen)) idx = int(np.argmax(F > 1_000)) center_freq = center_of_mass(F[:idx], A[:idx]) return center_freq, F[:idx], A[:idx]
[docs]def get_pulseshapes(pulsecenters: List[int], song: np.ndarray, win_hw: int) -> np.ndarray: """Extract waveforms around `pulsecenters` from `song`. In case of multi-channel recordings, will return the waveform on the channel with the maximum absolute value within +/-`win_hw` around the each pulsecenter. Args: pulsecenters (List[int]): Location of each pulse center in `song`, in samples song (np.ndarray): Audio data ([samples, channels]). win_hw (int): Half-width of the waveform cut out around each pulse center, in samples. Returns: np.ndarray: Extracted waveforms [2 * win_hw, nb_centers] """ pulseshapes = np.zeros((2 * win_hw, len(pulsecenters))) nb_channels = song.shape[1] for cnt, p in enumerate(pulsecenters): t0 = int(p - win_hw) t1 = int(p + win_hw) if t0 > 0 and t1 < song.shape[0]: if nb_channels > 1: tmp = song[t0:t1, :] loudest_channel = np.argmax(np.max(np.abs(tmp), axis=0)) pulseshapes[:, cnt] = tmp[:, loudest_channel].copy() else: pulseshapes[:, cnt] = song[t0:t1, 0].copy() return pulseshapes