forked from mrq/DL-Art-School
Stop dataset - attempt #2
This commit is contained in:
parent 17453ccbe8
commit 570ed327ed
@@ -81,6 +81,8 @@ def create_dataset(dataset_opt, return_collate=False):
         default_params.update(dataset_opt)
         dataset_opt = munchify(default_params)
         from data.audio.stop_prediction_dataset import StopPredictionDataset as D
+    elif mode == 'stop_prediction2':
+        from data.audio.stop_prediction_dataset_2 import StopPredictionDataset as D
     else:
         raise NotImplementedError('Dataset [{:s}] is not recognized.'.format(mode))
     dataset = D(dataset_opt)
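With this hunk, an options block whose mode is 'stop_prediction2' is dispatched to the new dataset class below. A minimal sketch of exercising the dispatch directly (the create_dataset import path is an assumption; the option keys mirror the dataset's own __main__ block):

    from data import create_dataset  # assumed location of the function patched above

    dataset_opt = {
        'mode': 'stop_prediction2',                         # routes to stop_prediction_dataset_2
        'path': 'D:\\data\\audio\\libritts\\stop_dataset',  # directory of <n>.wav / <n>_se.pth pairs
        'label_compaction': 4,
    }
    dataset = create_dataset(dataset_opt)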
92 codes/data/audio/stop_prediction_dataset_2.py Normal file
@@ -0,0 +1,92 @@
+import os
+import pathlib
+import random
+
+from munch import munchify
+from torch.utils.data import Dataset
+import torch
+from tqdm import tqdm
+
+from data.audio.nv_tacotron_dataset import save_mel_buffer_to_file
+from models.tacotron2 import hparams
+from models.tacotron2.layers import TacotronSTFT
+from models.tacotron2.taco_utils import load_wav_to_torch
+from utils.util import opt_get
+
+
+# A dataset that consumes the result from the script `produce_libri_stretched_dataset`, which itself is a combined
+# set of clips from the librivox corpus of equal length with the sentence alignment labeled.
+class StopPredictionDataset(Dataset):
+    def __init__(self, opt):
+        path = opt['path']
+        label_compaction = opt_get(opt, ['label_compaction'], 1)
+        hp = munchify(hparams.create_hparams())
+        cache_path = os.path.join(path, 'cache.pth')
+        if os.path.exists(cache_path):
+            self.files = torch.load(cache_path)
+        else:
+            print("Building cache..")
+            self.files = list(pathlib.Path(path).glob('*.wav'))
+            torch.save(self.files, cache_path)
+        self.sampling_rate = 22050  # Fixed since the underlying data is also fixed at this SR.
+        self.mel_length = 2000
+        self.stft = TacotronSTFT(
+            hp.filter_length, hp.hop_length, hp.win_length,
+            hp.n_mel_channels, hp.sampling_rate, hp.mel_fmin,
+            hp.mel_fmax)
+        self.label_compaction = label_compaction
+
+    def __getitem__(self, index):
+        audio, _ = load_wav_to_torch(self.files[index])
+        starts, ends = torch.load(str(self.files[index]).replace('.wav', '_se.pth'))
+
+        if audio.std() > 1:
+            print(f"Something is very wrong with the given audio. std_dev={audio.std()}. file={self.files[index]}")
+            return None
+        audio.clip_(-1, 1)
+        mels = self.stft.mel_spectrogram(audio.unsqueeze(0))[:, :, :self.mel_length].squeeze(0)
+
+        # Form labels.
+        labels_start = torch.zeros((2000 // self.label_compaction,), dtype=torch.long)
+        for s in starts:
+            # MEL compaction operates at a ratio of 1/256; the dataset also allows further compaction.
+            s = s // (256 * self.label_compaction)
+            if s >= 2000 // self.label_compaction:
+                continue
+            labels_start[s] = 1
+        labels_end = torch.zeros((2000 // self.label_compaction,), dtype=torch.long)
+        for e in ends:
+            e = e // (256 * self.label_compaction)
+            if e >= 2000 // self.label_compaction:
+                continue
+            labels_end[e] = 1
+
+        return {
+            'mels': mels,
+            'labels_start': labels_start,
+            'labels_end': labels_end,
+        }
+
+    def __len__(self):
+        return len(self.files)
+
+
+if __name__ == '__main__':
+    opt = {
+        'path': 'D:\\data\\audio\\libritts\\stop_dataset',
+        'label_compaction': 4,
+    }
+    ds = StopPredictionDataset(opt)
+    j = 0
+    for i in tqdm(range(100)):
+        b = ds[random.randint(0, len(ds) - 1)]  # randint is inclusive at both ends.
+        start_indices = torch.nonzero(b['labels_start']).squeeze(1)
+        end_indices = torch.nonzero(b['labels_end']).squeeze(1)
+        assert len(end_indices) <= len(start_indices)  # There should always be at least as many START tokens as END tokens.
+        for i in range(len(end_indices)):
+            s = start_indices[i].item() * 4  # Undo label_compaction (4) to recover MEL frame indices.
+            e = end_indices[i].item() * 4
+            m = b['mels'][:, s:e]
+            save_mel_buffer_to_file(m, f'{j}.npy')
+            j += 1
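A quick check of the label arithmetic above: one MEL frame spans 256 audio samples (the hop length), and label_compaction groups frames further, so each label slot covers 256 * label_compaction samples:

    # hop_length = 256, label_compaction = 4 -> one label covers 1024 audio samples
    label_compaction = 4
    num_labels = 2000 // label_compaction        # 500 label slots per 2000-frame MEL
    s = 352000                                   # a sentence start, in audio samples
    assert s // (256 * label_compaction) == 343  # the label index that gets set to 1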
@@ -62,27 +62,26 @@ class GptSegmentor(nn.Module):
                               attn_dropout=.1, ff_dropout=.1, non_causal_sequence_partition=self.MAX_MEL_FRAMES)
 
         self.final_norm = nn.LayerNorm(model_dim)
+        self.start_head = nn.Linear(model_dim, 1)
         self.stop_head = nn.Linear(model_dim, 1)
 
-    def forward(self, mel_inputs, termination_points=None):
+    def forward(self, mel_inputs, start_labels=None, end_labels=None):
         mel_emb = self.mel_encoder(mel_inputs)
         mel_emb = mel_emb.permute(0,2,1).contiguous()
         mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device))
 
         enc = self.gpt(mel_emb)
-        stop_logits = self.final_norm(enc)
-        stop_logits = self.stop_head(stop_logits)
-
-        if termination_points is not None:
-            # The MEL gets decimated to 1/4 the size by the encoder, so we need to do the same to the termination points.
-            termination_points = F.interpolate(termination_points.unsqueeze(1), size=mel_emb.shape[1], mode='area').squeeze()
-            termination_points = (termination_points > 0).float()
+        logits = self.final_norm(enc)
+        stop_logits = self.stop_head(logits)
+        start_logits = self.start_head(logits)
+
+        if start_labels is not None:
             # Compute loss
-            loss = F.binary_cross_entropy_with_logits(stop_logits.squeeze(-1), termination_points)
-            return loss.mean()
+            start_loss = F.binary_cross_entropy_with_logits(start_logits.squeeze(-1), start_labels.float())
+            end_loss = F.binary_cross_entropy_with_logits(stop_logits.squeeze(-1), end_labels.float())
+            return start_loss.mean(), end_loss.mean()
         else:
-            return stop_logits
+            return start_logits, stop_logits
 
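The reworked forward() returns a (start_loss, end_loss) pair during training and (start_logits, stop_logits) at inference. A hedged sketch of a training step, assuming model is a constructed GptSegmentor and batch comes from StopPredictionDataset through a DataLoader (the 1:1 loss weighting is an assumption; the trainer config governs this):

    mels = batch['mels']
    start_loss, end_loss = model(mels, start_labels=batch['labels_start'], end_labels=batch['labels_end'])
    loss = start_loss + end_loss  # assumed equal weighting of the two heads
    loss.backward()

Calling model(mels) with no labels instead yields the two per-frame logit tracks.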
@@ -36,7 +36,7 @@ def create_hparams(hparams_string=None, verbose=False):
         input_sample_rate=22050,  # When different from sampling_rate, dataset automatically interpolates to sampling_rate
         sampling_rate=22050,
         filter_length=1024,
-        hop_length=256,
+        hop_length=256,  # This means a MEL is 1/256th the equivalent audio.
         win_length=1024,
         n_mel_channels=80,
         mel_fmin=0.0,
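Making the hop-length comment concrete, since the stop dataset's clip length is derived from it:

    mel_frames = 2000                  # mel_length in StopPredictionDataset above
    hop_length = 256
    samples = mel_frames * hop_length  # 512000 audio samples per clip
    seconds = samples / 22050          # ~23.2 seconds of audio per training clip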
@@ -0,0 +1,99 @@
+# Produces a dataset of fixed-length audio clips from the libriTTS corpus, with sentence start/end alignments labeled.
+import os
+import random
+
+import audio2numpy
+import torch
+from scipy.io import wavfile
+from tqdm import tqdm
+
+from utils.audio_resampler import AudioResampler
+
+
+def secs_to_frames(secs, sr):
+    return int(secs * sr)
+
+
+def get_audio_clip(audio, sr, start, end):
+    start = secs_to_frames(start, sr)
+    end = secs_to_frames(end, sr)
+    assert end > start
+    if end >= audio.shape[0]:
+        return None
+    return audio[start:end]
+
+
+# Produces an audio clip that would produce a MEL spectrogram of length mel_length by parsing parsed_sentences starting
+# at starting_index and moving forwards until the full length is finished.
+# Returns:
+#   On failure, a tuple: (end_index, None, [], [])
+#   On success, a tuple: (end_index, clip, start_points, end_points)
+#     clip.shape = (<mel_length*256>,)
+#     start_points = list(ints) of where each sentence in the clip starts
+#     end_points = list(ints) of where each sentence in the clip ends
+def gather_clip(audio, parsed_sentences, starting_index, sr, mel_length):
+    audio_length = (mel_length * 256) / sr  # 256 is technically a hyperparameter (the MEL hop length), but I have no intent of changing it.
+    starts = []
+    ends = []
+    start, end = parsed_sentences[starting_index][4:6]
+    start = float(start)
+    end = float(end)
+    clipstart = max(start - random.random() * 2, 0)  # Offset start backwards by up to 2 seconds.
+    clipend = start + audio_length
+    clip = get_audio_clip(audio, sr, clipstart, clipend)
+    if clip is not None:
+        # Fetch the start and end points that go along with this clip.
+        starts.append(secs_to_frames(start - clipstart, sr))
+        while end < clipend:
+            ends.append(secs_to_frames(end - clipstart, sr))
+            starting_index += 1
+            if starting_index >= len(parsed_sentences):
+                break
+            start, end = parsed_sentences[starting_index][4:6]
+            start = float(start)
+            end = float(end)
+            if start < clipend:
+                starts.append(secs_to_frames(start - clipstart, sr))
+
+    return starting_index + 1, clip, starts, ends
+
+
+if __name__ == '__main__':
+    full_book_root = 'D:\\data\\audio\\libritts\\full_books\\mp3'
+    libri_root = 'D:\\data\\audio\\libritts\\test-clean'
+    desired_mel_length = 2000
+    desired_audio_sample_rate = 22050
+    output_dir = 'D:\\data\\audio\\libritts\\stop_dataset_eval'
+
+    os.makedirs(output_dir, exist_ok=True)
+    j = 0
+    readers = os.listdir(libri_root)
+    for it, reader_dir in enumerate(tqdm(readers)):
+        #if it <= 145:  # Hey idiot! If you change this, change j too!
+        #    continue
+        reader = os.path.join(libri_root, reader_dir)
+        if not os.path.isdir(reader):
+            continue
+        for chapter_dir in os.listdir(reader):
+            chapter = os.path.join(reader, chapter_dir)
+            if not os.path.isdir(chapter):
+                continue
+            id = f'{os.path.basename(reader)}_{os.path.basename(chapter)}'
+            book_file = os.path.join(chapter, f'{id}.book.tsv')
+            if not os.path.exists(book_file):
+                continue
+            with open(book_file, encoding='utf-8') as f:
+                full_chapter, sr = audio2numpy.open_audio(os.path.join(full_book_root, reader_dir, chapter_dir, f'{chapter_dir}.mp3'))
+                full_chapter = torch.tensor(full_chapter)
+                if len(full_chapter.shape) > 1:
+                    full_chapter = full_chapter[:, 0]  # Only use mono audio.
+                resampler = AudioResampler(sr, desired_audio_sample_rate, dtype=torch.float)
+                full_chapter = resampler(full_chapter.unsqueeze(0)).squeeze(0)
+                parsed_sentences = [line.strip().split('\t') for line in f]
+                i = 0
+                while i < len(parsed_sentences):
+                    i, clip, ns, ne = gather_clip(full_chapter, parsed_sentences, i, desired_audio_sample_rate, desired_mel_length)
+                    if clip is not None:
+                        wavfile.write(os.path.join(output_dir, f'{j}.wav'), desired_audio_sample_rate, clip.cpu().numpy())
+                        torch.save((ns, ne), os.path.join(output_dir, f'{j}_se.pth'))
+                        j += 1
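This script and StopPredictionDataset share a simple on-disk contract: <n>.wav holds a 22050 Hz clip and <n>_se.pth holds the (starts, ends) lists of sentence boundaries, in audio-sample offsets from the clip start. Reading one pair back looks like:

    import torch
    from models.tacotron2.taco_utils import load_wav_to_torch

    audio, _ = load_wav_to_torch('0.wav')
    starts, ends = torch.load('0_se.pth')  # sample offsets; // 256 gives MEL frame indices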
@@ -1,6 +1,9 @@
+import pathlib
+
 import numpy
 import torch
 from scipy.io import wavfile
+from tqdm import tqdm
 
 from models.waveglow.waveglow import WaveGlow
@@ -21,8 +24,12 @@ class Vocoder:
 
 
 if __name__ == '__main__':
-    inp = '3.npy'
-    mel = torch.tensor(numpy.load(inp)).to('cuda')
-    vocoder = Vocoder()
-    wav = vocoder.transform_mel_to_audio(mel)
-    wavfile.write(f'{inp}.wav', 22050, wav[0].cpu().numpy())
+    path = 'data/audio'
+    files = list(pathlib.Path(path).glob('*.npy'))
+
+    for inp in tqdm(files):
+        inp = str(inp)
+        mel = torch.tensor(numpy.load(inp)).to('cuda')
+        vocoder = Vocoder()
+        wav = vocoder.transform_mel_to_audio(mel)
+        wavfile.write(f'{inp}.wav', 22050, wav[0].cpu().numpy())
@@ -282,7 +282,7 @@ class Trainer:
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_lrdvae_audio_clips.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_gpt_stop_libritts.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()
249 codes/utils/audio_resampler.py Normal file
@@ -0,0 +1,249 @@
+import torch
+import numpy as np
+from scipy import special
+
+
+# Courtesy of https://www.kaggle.com/smallyellowduck/fast-audio-resampling-layer-in-pytorch
+class AudioResampler(torch.nn.Module):
+    """
+    Efficiently resample audio signals.
+    This module is much faster than resampling with librosa because it exploits pytorch's efficient conv1d operations.
+    This module is also faster than the existing pytorch resample function in
+    https://github.com/pytorch/audio/blob/b6a61c3f7d0267c77f8626167cc1eda0335f2753/torchaudio/compliance/kaldi.py#L892
+
+    Based on
+    https://github.com/danpovey/filtering/blob/master/lilfilter/resampler.py
+    with improvements to include additional filter types and input parameters that align with the librosa api.
+    """
+
+    def __init__(self,
+                 input_sr, output_sr, dtype,
+                 num_zeros=64, cutoff_ratio=0.95, filter='kaiser', beta=14.0):
+        super().__init__()  # init the base class
+        """
+        This creates an object that can apply a symmetric FIR filter
+        based on torch.nn.functional.conv1d.
+
+        Args:
+          input_sr: The input sampling rate, AS AN INTEGER.
+              It does not have to be the real sampling rate but should
+              have the correct ratio with output_sr.
+          output_sr: The output sampling rate, AS AN INTEGER.
+              It is the ratio with the input sampling rate that is
+              important here.
+          dtype: The torch dtype to use for computations (it would be preferable to
+              set things up so passing the dtype isn't necessary).
+          num_zeros: The number of zeros per side in the (sinc*hanning-window)
+              filter function. More is more accurate, but 64 is already
+              quite a lot. The kernel size is 2*num_zeros + 1.
+          cutoff_ratio: The filter rolloff point as a fraction of the
+              Nyquist frequency.
+          filter: one of ['kaiser', 'kaiser_best', 'kaiser_fast', 'hann']
+          beta: parameter for the 'kaiser' filter
+
+        You can think of this algorithm as dividing up the signals
+        (input, output) into blocks where there are `input_sr` input
+        samples and `output_sr` output samples. Then we treat it
+        using convolutional code, imagining there are `input_sr`
+        input channels and `output_sr` output channels per time step.
+        """
+        assert isinstance(input_sr, int) and isinstance(output_sr, int)
+        if input_sr == output_sr:
+            self.resample_type = 'trivial'
+            return
+
+        def gcd(a, b):
+            """ Return the greatest common divisor of a and b. """
+            assert isinstance(a, int) and isinstance(b, int)
+            if b == 0:
+                return a
+            else:
+                return gcd(b, a % b)
+
+        d = gcd(input_sr, output_sr)
+        input_sr, output_sr = input_sr // d, output_sr // d
+
+        assert dtype in [torch.float32, torch.float64]
+        assert num_zeros > 3  # a reasonable bare minimum
+        np_dtype = np.float32 if dtype == torch.float32 else np.float64
+
+        assert filter in ['hann', 'kaiser', 'kaiser_best', 'kaiser_fast']
+
+        if filter == 'kaiser_best':
+            num_zeros = 64
+            beta = 14.769656459379492
+            cutoff_ratio = 0.9475937167399596
+            filter = 'kaiser'
+        elif filter == 'kaiser_fast':
+            num_zeros = 16
+            beta = 8.555504641634386
+            cutoff_ratio = 0.85
+            filter = 'kaiser'
+
+        # Define one 'block' of samples as `input_sr` input samples
+        # and `output_sr` output samples. We can divide up
+        # the samples into these blocks and have the blocks be
+        # in correspondence.
+
+        # The sinc function will have, on average, `zeros_per_block`
+        # zeros per block.
+        zeros_per_block = min(input_sr, output_sr) * cutoff_ratio
+
+        # The convolutional kernel size will be n = (blocks_per_side*2 + 1),
+        # i.e. we add that many blocks on each side of the central block. The
+        # window radius (defined as distance from center to edge)
+        # is `blocks_per_side` blocks. This ensures that each sample in the
+        # central block can "see" all the samples in its window.
+        #
+        # Assuming the following division is not exact, adding 1
+        # will have the same effect as rounding up.
+        # blocks_per_side = 1 + int(num_zeros / zeros_per_block)
+        blocks_per_side = int(np.ceil(num_zeros / zeros_per_block))
+
+        kernel_width = 2 * blocks_per_side + 1
+
+        # We want the weights as used by torch's conv1d code; format is
+        # (out_channels, in_channels, kernel_width).
+        # https://pytorch.org/docs/stable/nn.functional.html
+        # (Placeholder of the right shape; overwritten by the windowed-sinc weights computed below.)
+        weights = torch.empty((output_sr, input_sr, kernel_width), dtype=dtype)
+
+        # Computations involving time will be in units of 1 block. Actually this
+        # is the same as the `canonical` time axis since each block has input_sr
+        # input samples, so it would be one of whatever time unit we are using.
+        window_radius_in_blocks = blocks_per_side
+
+        # The `times` below will end up being the args to the sinc function.
+        # For the shapes of the things below, look at the args to `reshape`. The terms
+        # below will get expanded to shape (output_sr, input_sr, kernel_width) through
+        # broadcasting.
+        # We want it so that, assuming input_sr == output_sr, along the diagonal of
+        # the central block we have t == 0.
+        # The signs of the output_sr and input_sr terms need to be opposite. The
+        # sign that the kernel_width term needs to be will depend on whether it's
+        # convolution or correlation, and the logic is tricky.. I will just find
+        # which sign works.
+        times = (
+            np.arange(output_sr, dtype=np_dtype).reshape((output_sr, 1, 1)) / output_sr -
+            np.arange(input_sr, dtype=np_dtype).reshape((1, input_sr, 1)) / input_sr -
+            (np.arange(kernel_width, dtype=np_dtype).reshape((1, 1, kernel_width)) - blocks_per_side))
+
+        def hann_window(a):
+            """
+            hann_window returns the Hann window on [-1,1], which is zero
+            if a < -1 or a > 1, and otherwise 0.5 + 0.5*cos(a*pi).
+            This is applied elementwise to a, which should be a NumPy array.
+
+            The heaviside function returns (a > 0 ? 1 : 0).
+            """
+            return np.heaviside(1 - np.abs(a), 0.0) * (0.5 + 0.5 * np.cos(a * np.pi))
+
+        def kaiser_window(a, beta):
+            w = special.i0(beta * np.sqrt(np.clip(1 - ((a - 0.0) / 1.0) ** 2.0, 0.0, 1.0))) / special.i0(beta)
+            return np.heaviside(1 - np.abs(a), 0.0) * w
+
+        # The weights below are a sinc function times a Hann-window function.
+        #
+        # Multiplication by zeros_per_block normalizes the sinc function
+        # (to compensate for scaling on the x-axis), so that the integral is 1.
+        #
+        # Division by input_sr normalizes the input function. Think of the input
+        # as a stream of dirac deltas passing through a low pass filter:
+        # in order to have the same magnitude as the original input function,
+        # we need to divide by the number of those deltas per unit time.
+        if filter == 'hann':
+            weights = (np.sinc(times * zeros_per_block)
+                       * hann_window(times / window_radius_in_blocks)
+                       * zeros_per_block / input_sr)
+        else:
+            weights = (np.sinc(times * zeros_per_block)
+                       * kaiser_window(times / window_radius_in_blocks, beta)
+                       * zeros_per_block / input_sr)
+
+        self.input_sr = input_sr
+        self.output_sr = output_sr
+
+        # weights has dim (output_sr, input_sr, kernel_width).
+        # If output_sr == 1, we can fold the input_sr into the
+        # kernel_width (i.e. have just 1 input channel); this will make the
+        # convolution faster and avoid unnecessary reshaping.
+
+        assert weights.shape == (output_sr, input_sr, kernel_width)
+        if output_sr == 1:
+            self.resample_type = 'integer_downsample'
+            self.padding = input_sr * blocks_per_side
+            weights = torch.tensor(weights, dtype=dtype, requires_grad=False)
+            self.weights = weights.transpose(1, 2).contiguous().view(1, 1, input_sr * kernel_width)
+
+        elif input_sr == 1:
+            # In this case we'll be doing conv_transpose, so we want the same weights that
+            # we would have if we were *downsampling* by this factor -- i.e. as if input_sr
+            # and output_sr had been swapped.
+            self.resample_type = 'integer_upsample'
+            self.padding = output_sr * blocks_per_side
+            weights = torch.tensor(weights, dtype=dtype, requires_grad=False)
+            self.weights = weights.flip(2).transpose(0, 2).contiguous().view(1, 1, output_sr * kernel_width)
+        else:
+            self.resample_type = 'general'
+            self.reshaped = False
+            self.padding = blocks_per_side
+            self.weights = torch.tensor(weights, dtype=dtype, requires_grad=False)
+
+        self.weights = torch.nn.Parameter(self.weights, requires_grad=False)
+
+    @torch.no_grad()
+    def forward(self, data):
+        """
+        Resample the data.
+
+        Args:
+          data: a torch.Tensor with the same dtype as was passed to the
+              constructor.
+              There must be 2 axes, interpreted as (minibatch_size, sequence_length)...
+              the minibatch_size may in practice be the number of channels.
+
+        Return: Returns a torch.Tensor with the same dtype as the input, and
+            dimension (minibatch_size, (sequence_length//input_sr)*output_sr),
+            where input_sr and output_sr are the corresponding constructor args,
+            modified to remove any common factors.
+        """
+        if self.resample_type == 'trivial':
+            return data
+        elif self.resample_type == 'integer_downsample':
+            (minibatch_size, seq_len) = data.shape
+            # will be shape (minibatch_size, in_channels, seq_len) with in_channels == 1
+            data = data.unsqueeze(1)
+            data = torch.nn.functional.conv1d(data,
+                                              self.weights,
+                                              stride=self.input_sr,
+                                              padding=self.padding)
+            # shape will be (minibatch_size, out_channels = 1, seq_len);
+            # return as (minibatch_size, seq_len)
+            return data.squeeze(1)
+
+        elif self.resample_type == 'integer_upsample':
+            data = data.unsqueeze(1)
+            data = torch.nn.functional.conv_transpose1d(data,
+                                                        self.weights,
+                                                        stride=self.output_sr,
+                                                        padding=self.padding)
+
+            return data.squeeze(1)
+        else:
+            assert self.resample_type == 'general'
+            (minibatch_size, seq_len) = data.shape
+            num_blocks = seq_len // self.input_sr
+            if num_blocks == 0:
+                # TODO: pad with zeros.
+                raise RuntimeError("Signal is too short to resample")
+            # data = data[:, 0:(num_blocks*self.input_sr)]  # Truncate input
+            data = data[:, 0:(num_blocks * self.input_sr)].view(minibatch_size, num_blocks, self.input_sr)
+
+            # Torch's conv1d expects input data with shape (minibatch, in_channels, time_steps), so transpose.
+            data = data.transpose(1, 2)
+
+            data = torch.nn.functional.conv1d(data, self.weights,
+                                              padding=self.padding)
+
+            assert data.shape == (minibatch_size, self.output_sr, num_blocks)
+            return data.transpose(1, 2).contiguous().view(minibatch_size, num_blocks * self.output_sr)
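A usage sketch mirroring how the dataset-production script above drives the resampler; the 24 kHz source rate is only an example (gcd reduction turns 24000:22050 into the integer ratio 160:147, so one second in yields exactly 22050 samples out):

    import torch
    from utils.audio_resampler import AudioResampler

    resampler = AudioResampler(24000, 22050, dtype=torch.float32)
    audio = torch.randn(1, 24000)  # (minibatch, samples): one second at 24 kHz
    out = resampler(audio)         # -> shape (1, 22050)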