# DL-Art-School/dlas/utils/audio_resampler.py
import numpy as np
import torch
from scipy import special
# Courtesy of https://www.kaggle.com/smallyellowduck/fast-audio-resampling-layer-in-pytorch
class AudioResampler(torch.nn.Module):
    """
    Efficiently resample audio signals.

    This module is much faster than resampling with librosa because it
    exploits pytorch's efficient conv1d operations. It is also faster than
    the existing pytorch resample function in
    https://github.com/pytorch/audio/blob/b6a61c3f7d0267c77f8626167cc1eda0335f2753/torchaudio/compliance/kaldi.py#L892

    Based on
    https://github.com/danpovey/filtering/blob/master/lilfilter/resampler.py
    with improvements to include additional filter types and input parameters
    that align with the librosa api.
    """

    def __init__(self,
                 input_sr, output_sr, dtype,
                 num_zeros=64, cutoff_ratio=0.95, filter='kaiser', beta=14.0):
        """
        This creates an object that can apply a symmetric FIR filter
        based on torch.nn.functional.conv1d.

        Args:
          input_sr: The input sampling rate, AS AN INTEGER.
              It does not have to be the real sampling rate but should
              have the correct ratio with output_sr.
          output_sr: The output sampling rate, AS AN INTEGER.
              It is the ratio with the input sampling rate that is
              important here.
          dtype: The torch dtype to use for computations (would be preferable
              to set things up so passing the dtype isn't necessary).
          num_zeros: The number of zeros per side in the (sinc*window)
              filter function. More is more accurate, but 64 is already
              quite a lot. The kernel size is 2*num_zeros + 1.
          cutoff_ratio: The filter rolloff point as a fraction of the
              Nyquist frequency.
          filter: one of ['kaiser', 'kaiser_best', 'kaiser_fast', 'hann']
          beta: shape parameter for the 'kaiser' filter.

        You can think of this algorithm as dividing up the signals
        (input, output) into blocks where there are `input_sr` input
        samples and `output_sr` output samples. Then we treat it
        using convolutional code, imagining there are `input_sr`
        input channels and `output_sr` output channels per time step.
        """
        super().__init__()  # init the base class
        assert isinstance(input_sr, int) and isinstance(output_sr, int)
        if input_sr == output_sr:
            # Identical rates: forward() is the identity, no filter needed.
            self.resample_type = 'trivial'
            return

        # Only the ratio of the rates matters; reduce it to lowest terms.
        d = int(np.gcd(input_sr, output_sr))
        input_sr, output_sr = input_sr // d, output_sr // d

        assert dtype in [torch.float32, torch.float64]
        assert num_zeros > 3  # a reasonable bare minimum
        np_dtype = np.float32 if dtype == torch.float32 else np.float64

        assert filter in ['hann', 'kaiser', 'kaiser_best', 'kaiser_fast']
        # The *_best / *_fast presets mirror resampy/librosa parameter
        # choices; both reduce to a plain Kaiser window.
        if filter == 'kaiser_best':
            num_zeros = 64
            beta = 14.769656459379492
            cutoff_ratio = 0.9475937167399596
            filter = 'kaiser'
        elif filter == 'kaiser_fast':
            num_zeros = 16
            beta = 8.555504641634386
            cutoff_ratio = 0.85
            filter = 'kaiser'

        # Define one 'block' of samples as `input_sr` input samples
        # and `output_sr` output samples. We can divide up the samples
        # into these blocks and have the blocks be in correspondence.
        # The sinc function will have, on average, `zeros_per_block`
        # zeros per block.
        zeros_per_block = min(input_sr, output_sr) * cutoff_ratio

        # The convolutional kernel size will be n = (blocks_per_side*2 + 1),
        # i.e. we add that many blocks on each side of the central block. The
        # window radius (defined as distance from center to edge)
        # is `blocks_per_side` blocks. This ensures that each sample in the
        # central block can "see" all the samples in its window.
        blocks_per_side = int(np.ceil(num_zeros / zeros_per_block))
        kernel_width = 2 * blocks_per_side + 1

        # Computations involving time will be in units of 1 block. Actually
        # this is the same as the `canonical` time axis since each block has
        # input_sr input samples, so it would be one of whatever time unit we
        # are using.
        window_radius_in_blocks = blocks_per_side

        # The `times` below will end up being the args to the sinc function.
        # The terms broadcast to shape (output_sr, input_sr, kernel_width),
        # which is exactly the weight layout torch's conv1d expects:
        # (out_channels, in_channels, kernel_width).
        # We want it so that, assuming input_sr == output_sr, along the
        # diagonal of the central block we have t == 0.
        # The signs of the output_sr and input_sr terms need to be opposite.
        # The sign of the kernel_width term depends on whether it's
        # convolution or correlation, and the logic is tricky; the sign below
        # is the one that works.
        times = (
            np.arange(output_sr, dtype=np_dtype).reshape((output_sr, 1, 1)) / output_sr -
            np.arange(input_sr, dtype=np_dtype).reshape((1, input_sr, 1)) / input_sr -
            (np.arange(kernel_width, dtype=np_dtype).reshape((1, 1, kernel_width)) - blocks_per_side))

        def hann_window(a):
            """
            Hann window on [-1, 1]: zero if |a| > 1, otherwise
            0.5 + 0.5*cos(a*pi). Applied elementwise to the NumPy array `a`.
            (np.heaviside(x, 0.0) returns (x > 0 ? 1 : 0).)
            """
            return np.heaviside(1 - np.abs(a), 0.0) * (0.5 + 0.5 * np.cos(a * np.pi))

        def kaiser_window(a, beta):
            """Kaiser window on [-1, 1] with shape parameter `beta`,
            zero outside that interval."""
            w = special.i0(
                beta * np.sqrt(np.clip(1 - a ** 2.0, 0.0, 1.0))) / special.i0(beta)
            return np.heaviside(1 - np.abs(a), 0.0) * w

        # The weights below are a sinc function times a window function.
        #
        # Multiplication by zeros_per_block normalizes the sinc function
        # (to compensate for scaling on the x-axis), so that the integral
        # is 1.
        #
        # Division by input_sr normalizes the input function. Think of the
        # input as a stream of dirac deltas passing through a low pass filter:
        # in order to have the same magnitude as the original input function,
        # we need to divide by the number of those deltas per unit time.
        if filter == 'hann':
            window = hann_window(times / window_radius_in_blocks)
        else:
            window = kaiser_window(times / window_radius_in_blocks, beta)
        weights = (np.sinc(times * zeros_per_block)
                   * window
                   * zeros_per_block / input_sr)

        self.input_sr = input_sr
        self.output_sr = output_sr

        # weights has dim (output_sr, input_sr, kernel_width).
        assert weights.shape == (output_sr, input_sr, kernel_width)
        if output_sr == 1:
            # Pure integer downsampling: fold input_sr into the kernel_width
            # (i.e. have just 1 input channel); this makes the convolution
            # faster and avoids unnecessary reshaping in forward().
            self.resample_type = 'integer_downsample'
            self.padding = input_sr * blocks_per_side
            weights = torch.tensor(weights, dtype=dtype, requires_grad=False)
            self.weights = weights.transpose(1, 2).contiguous().view(
                1, 1, input_sr * kernel_width)
        elif input_sr == 1:
            # Pure integer upsampling: we'll be doing conv_transpose, so we
            # want the same weights that we would have if we were
            # *downsampling* by this factor -- i.e. as if input_sr, output_sr
            # had been swapped.
            self.resample_type = 'integer_upsample'
            self.padding = output_sr * blocks_per_side
            weights = torch.tensor(weights, dtype=dtype, requires_grad=False)
            self.weights = weights.flip(2).transpose(
                0, 2).contiguous().view(1, 1, output_sr * kernel_width)
        else:
            self.resample_type = 'general'
            self.reshaped = False
            self.padding = blocks_per_side
            self.weights = torch.tensor(
                weights, dtype=dtype, requires_grad=False)
        # Register as a frozen Parameter so .to()/.cuda() on the module moves
        # the filter along with it.
        self.weights = torch.nn.Parameter(self.weights, requires_grad=False)

    @torch.no_grad()
    def forward(self, data):
        """
        Resample the data.

        Args:
          data: a torch.Tensor with the same dtype as was passed to the
              constructor. There must be 2 axes, interpreted as
              (minibatch_size, sequence_length)... the minibatch_size may in
              practice be the number of channels.

        Return:
          Returns a torch.Tensor with the same dtype as the input, and
          dimension (minibatch_size, (sequence_length//input_sr)*output_sr),
          where input_sr and output_sr are the corresponding constructor args,
          modified to remove any common factors.

        Raises:
          RuntimeError: in the 'general' case, if the signal is shorter than
              one block of `input_sr` samples.
        """
        if self.resample_type == 'trivial':
            return data
        elif self.resample_type == 'integer_downsample':
            (minibatch_size, seq_len) = data.shape
            # shape (minibatch_size, in_channels, seq_len) with in_channels == 1
            data = data.unsqueeze(1)
            data = torch.nn.functional.conv1d(data,
                                              self.weights,
                                              stride=self.input_sr,
                                              padding=self.padding)
            # shape is (minibatch_size, out_channels = 1, seq_len);
            # return as (minibatch_size, seq_len)
            return data.squeeze(1)
        elif self.resample_type == 'integer_upsample':
            data = data.unsqueeze(1)
            data = torch.nn.functional.conv_transpose1d(data,
                                                        self.weights,
                                                        stride=self.output_sr,
                                                        padding=self.padding)
            return data.squeeze(1)
        else:
            assert self.resample_type == 'general'
            (minibatch_size, seq_len) = data.shape
            num_blocks = seq_len // self.input_sr
            if num_blocks == 0:
                # TODO: pad with zeros.
                raise RuntimeError("Signal is too short to resample")
            # Truncate to a whole number of blocks and expose the
            # within-block index as a separate axis.
            data = data[:, 0:(num_blocks * self.input_sr)
                        ].view(minibatch_size, num_blocks, self.input_sr)
            # Torch's conv1d expects input data with shape
            # (minibatch, in_channels, time_steps), so transpose.
            data = data.transpose(1, 2)
            data = torch.nn.functional.conv1d(data, self.weights,
                                              padding=self.padding)
            assert data.shape == (minibatch_size, self.output_sr, num_blocks)
            return data.transpose(1, 2).contiguous().view(minibatch_size, num_blocks * self.output_sr)