Stop dataset - attempt #2

This commit is contained in:
James Betker 2021-08-18 18:29:38 -06:00
parent 17453ccbe8
commit 570ed327ed
9 changed files with 466 additions and 18 deletions

View File

@@ -81,6 +81,8 @@ def create_dataset(dataset_opt, return_collate=False):
default_params.update(dataset_opt)
dataset_opt = munchify(default_params)
from data.audio.stop_prediction_dataset import StopPredictionDataset as D
elif mode == 'stop_prediction2':
from data.audio.stop_prediction_dataset_2 import StopPredictionDataset as D
else:
raise NotImplementedError('Dataset [{:s}] is not recognized.'.format(mode))
dataset = D(dataset_opt)
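
For context, a hedged sketch of the options this new branch consumes ('path' and 'label_compaction' are the keys StopPredictionDataset actually reads, see the new file below; whether 'mode' lives in the same dict is an assumption about the surrounding factory code):

dataset_opt = {
    'mode': 'stop_prediction2',
    'path': '/data/audio/libritts/stop_dataset',  # output of produce_libri_stretched_dataset
    'label_compaction': 4,                        # optional; defaults to 1
}
dataset = create_dataset(dataset_opt)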

View File

@@ -0,0 +1,92 @@
import os
import pathlib
import random
from munch import munchify
from torch.utils.data import Dataset
import torch
from tqdm import tqdm
from data.audio.nv_tacotron_dataset import save_mel_buffer_to_file
from models.tacotron2 import hparams
from models.tacotron2.layers import TacotronSTFT
from models.tacotron2.taco_utils import load_wav_to_torch
from utils.util import opt_get
# A dataset that consumes the output of the `produce_libri_stretched_dataset` script: equal-length clips drawn from
# the LibriVox corpus with sentence start/end alignments labeled.
class StopPredictionDataset(Dataset):
def __init__(self, opt):
path = opt['path']
label_compaction = opt_get(opt, ['label_compaction'], 1)
hp = munchify(hparams.create_hparams())
cache_path = os.path.join(path, 'cache.pth')
if os.path.exists(cache_path):
self.files = torch.load(cache_path)
else:
print("Building cache..")
self.files = list(pathlib.Path(path).glob('*.wav'))
torch.save(self.files, cache_path)
self.sampling_rate = 22050 # Fixed since the underlying data is also fixed at this SR.
self.mel_length = 2000
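# 2000 mel frames * 256-sample hop / 22050 Hz ≈ 23.2 seconds of audio per clip.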
self.stft = TacotronSTFT(
hp.filter_length, hp.hop_length, hp.win_length,
hp.n_mel_channels, hp.sampling_rate, hp.mel_fmin,
hp.mel_fmax)
self.label_compaction = label_compaction
def __getitem__(self, index):
audio, _ = load_wav_to_torch(self.files[index])
starts, ends = torch.load(str(self.files[index]).replace('.wav', '_se.pth'))
if audio.std() > 1:
print(f"Something is very wrong with the given audio. std_dev={audio.std()}. file={self.files[index]}")
return None
audio.clip_(-1, 1)
mels = self.stft.mel_spectrogram(audio.unsqueeze(0))[:, :, :self.mel_length].squeeze(0)
# Form labels.
labels_start = torch.zeros((self.mel_length // self.label_compaction,), dtype=torch.long)
for s in starts:
# Mel compaction operates at a ratio of 1/256, the dataset also allows further compaction.
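# e.g. with label_compaction=4, an alignment point at audio sample 51200 maps to mel frame
# 51200 // 256 = 200 and to label bin 51200 // (256 * 4) = 50.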
s = s // (256 * self.label_compaction)
if s >= self.mel_length // self.label_compaction:
continue
labels_start[s] = 1
labels_end = torch.zeros((self.mel_length // self.label_compaction,), dtype=torch.long)
for e in ends:
e = e // (256 * self.label_compaction)
if e >= self.mel_length // self.label_compaction:
continue
labels_end[e] = 1
return {
'mels': mels,
'labels_start': labels_start,
'labels_end': labels_end,
}
def __len__(self):
return len(self.files)
if __name__ == '__main__':
opt = {
'path': 'D:\\data\\audio\\libritts\\stop_dataset',
'label_compaction': 4,
}
ds = StopPredictionDataset(opt)
j = 0
for i in tqdm(range(100)):
b = ds[random.randint(0, len(ds) - 1)]
start_indices = torch.nonzero(b['labels_start']).squeeze(1)
end_indices = torch.nonzero(b['labels_end']).squeeze(1)
assert len(end_indices) <= len(start_indices)  # There should always be at least as many START tokens as END tokens.
for k in range(len(end_indices)):
s = start_indices[k].item()*4
e = end_indices[k].item()*4
m = b['mels'][:, s:e]
save_mel_buffer_to_file(m, f'{j}.npy')
j += 1

View File

@@ -62,27 +62,26 @@ class GptSegmentor(nn.Module):
attn_dropout=.1, ff_dropout=.1, non_causal_sequence_partition=self.MAX_MEL_FRAMES)
self.final_norm = nn.LayerNorm(model_dim)
self.start_head = nn.Linear(model_dim, 1)
self.stop_head = nn.Linear(model_dim, 1)
def forward(self, mel_inputs, termination_points=None):
def forward(self, mel_inputs, start_labels=None, end_labels=None):
mel_emb = self.mel_encoder(mel_inputs)
mel_emb = mel_emb.permute(0,2,1).contiguous()
mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device))
enc = self.gpt(mel_emb)
stop_logits = self.final_norm(enc)
stop_logits = self.stop_head(stop_logits)
if termination_points is not None:
# The MEL gets decimated to 1/4 the size by the encoder, so we need to do the same to the termination points.
termination_points = F.interpolate(termination_points.unsqueeze(1), size=mel_emb.shape[1], mode='area').squeeze()
termination_points = (termination_points > 0).float()
logits = self.final_norm(enc)
stop_logits = self.stop_head(logits)
start_logits = self.start_head(logits)
if start_labels is not None:
# Compute loss
loss = F.binary_cross_entropy_with_logits(stop_logits.squeeze(-1), termination_points)
return loss.mean()
start_loss = F.binary_cross_entropy_with_logits(start_logits.squeeze(-1), start_labels.float())
end_loss = F.binary_cross_entropy_with_logits(stop_logits.squeeze(-1), end_labels.float())
return start_loss.mean(), end_loss.mean()
else:
return stop_logits
return start_logits, stop_logits
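
A self-contained sketch of the loss math in the new forward() (shapes are illustrative: with label_compaction=4 in the dataset and the 4x mel decimation described in the old code's comment, a 2000-frame mel yields T=500 label bins):

import torch
import torch.nn.functional as F

B, T = 2, 500                             # batch size, decimated mel frames (assumed)
start_logits = torch.randn(B, T, 1)       # stand-in for start_head output
stop_logits = torch.randn(B, T, 1)        # stand-in for stop_head output
start_labels = torch.zeros(B, T)
start_labels[:, 10] = 1                   # one sentence start per clip (illustrative)
end_labels = torch.zeros(B, T)
end_labels[:, 480] = 1                    # one sentence end per clip (illustrative)
start_loss = F.binary_cross_entropy_with_logits(start_logits.squeeze(-1), start_labels)
end_loss = F.binary_cross_entropy_with_logits(stop_logits.squeeze(-1), end_labels)
print(start_loss.item(), end_loss.item())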

View File

@@ -36,7 +36,7 @@ def create_hparams(hparams_string=None, verbose=False):
input_sample_rate=22050, # When different from sampling_rate, dataset automatically interpolates to sampling_rate
sampling_rate=22050,
filter_length=1024,
hop_length=256,
hop_length=256, # Each MEL frame covers 256 audio samples, so a MEL spectrogram is 1/256th the length of its audio.
win_length=1024,
n_mel_channels=80,
mel_fmin=0.0,

View File

@@ -0,0 +1,99 @@
# Produces the stop-prediction dataset: fixed-length audio clips from libriTTS with per-sentence start/end alignments labeled.
import os
import random
import audio2numpy
import torch
from scipy.io import wavfile
from tqdm import tqdm
from utils.audio_resampler import AudioResampler
def secs_to_frames(secs, sr):
return int(secs*sr)
def get_audio_clip(audio, sr, start, end):
start = secs_to_frames(start, sr)
end = secs_to_frames(end, sr)
assert end > start
if end >= audio.shape[0]:
return None
return audio[start:end]
# Produces an audio clip that would yield a MEL spectrogram of at least mel_length frames by walking parsed_sentences
# forwards from starting_index until the required length is covered.
# Returns:
#   On failure: tuple (end_index, None, [], [])
#   On success: tuple (end_index, clip, start_points, end_points)
#     clip: 1-D tensor of at least mel_length*256 samples (plus up to ~2 seconds of leading context)
#     start_points: list of ints, sample offsets at which each sentence in the clip starts
#     end_points: list of ints, sample offsets at which each sentence in the clip ends
def gather_clip(audio, parsed_sentences, starting_index, sr, mel_length):
audio_length = (mel_length * 256) / sr  # 256 is the MEL hop length; technically a hyperparameter, but I have no intent of changing it.
starts = []
ends = []
start, end = parsed_sentences[starting_index][4:6]
start = float(start)
end = float(end)
clipstart = max(start - random.random() * 2, 0) # Offset start backwards by up to 2 seconds
clipend = start + audio_length
clip = get_audio_clip(audio, sr, clipstart, clipend)
if clip is not None:
# Fetch the start and endpoints that go along with this clip.
starts.append(secs_to_frames(start-clipstart, sr))
while end < clipend:
ends.append(secs_to_frames(end-clipstart, sr))
starting_index += 1
if starting_index >= len(parsed_sentences):
break
start, end = parsed_sentences[starting_index][4:6]
start = float(start)
end = float(end)
if start < clipend:
starts.append(secs_to_frames(start-clipstart, sr))
return starting_index+1, clip, starts, ends
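# Worked example (hypothetical numbers): with sr=22050 and mel_length=2000, audio_length is 2000*256/22050 ≈ 23.2s.
# If sentence 0 spans 10.0s-14.5s and the random back-offset is 1.0s, then clipstart=9.0s and clipend=33.2s; that
# sentence contributes start frame secs_to_frames(1.0, 22050) = 22050 and end frame secs_to_frames(5.5, 22050) = 121275.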
if __name__ == '__main__':
full_book_root = 'D:\\data\\audio\\libritts\\full_books\\mp3'
libri_root = 'D:\\data\\audio\\libritts\\test-clean'
desired_mel_length = 2000
desired_audio_sample_rate = 22050
output_dir = 'D:\\data\\audio\\libritts\\stop_dataset_eval'
os.makedirs(output_dir, exist_ok=True)
j = 0
readers = os.listdir(libri_root)
for it, reader_dir in enumerate(tqdm(readers)):
#if it <= 145: # Hey idiot! If you change this, change j too!
# continue
reader = os.path.join(libri_root, reader_dir)
if not os.path.isdir(reader):
continue
for chapter_dir in os.listdir(reader):
chapter = os.path.join(reader, chapter_dir)
if not os.path.isdir(chapter):
continue
id = f'{os.path.basename(reader)}_{os.path.basename(chapter)}'
book_file = os.path.join(chapter, f'{id}.book.tsv')
if not os.path.exists(book_file):
continue
with open(book_file, encoding='utf-8') as f:
full_chapter, sr = audio2numpy.open_audio(os.path.join(full_book_root, reader_dir, chapter_dir, f'{chapter_dir}.mp3'))
full_chapter = torch.tensor(full_chapter)
if len(full_chapter.shape) > 1:
full_chapter = full_chapter[:, 0] # Only use mono-audio.
resampler = AudioResampler(sr, desired_audio_sample_rate, dtype=torch.float)
full_chapter = resampler(full_chapter.unsqueeze(0)).squeeze(0)
parsed_sentences = [line.strip().split('\t') for line in f]
i = 0
while i < len(parsed_sentences):
i, clip, ns, ne = gather_clip(full_chapter, parsed_sentences, i, desired_audio_sample_rate, desired_mel_length)
if clip is not None:
wavfile.write(os.path.join(output_dir, f'{j}.wav'), desired_audio_sample_rate, clip.cpu().numpy())
torch.save((ns,ne), os.path.join(output_dir, f'{j}_se.pth'))
j += 1

View File

@@ -1,6 +1,9 @@
import pathlib
import numpy
import torch
from scipy.io import wavfile
from tqdm import tqdm
from models.waveglow.waveglow import WaveGlow
@@ -21,8 +24,12 @@ class Vocoder:
if __name__ == '__main__':
inp = '3.npy'
mel = torch.tensor(numpy.load(inp)).to('cuda')
vocoder = Vocoder()
wav = vocoder.transform_mel_to_audio(mel)
wavfile.write(f'{inp}.wav', 22050, wav[0].cpu().numpy())
path = 'data/audio'
files = list(pathlib.Path(path).glob('*.npy'))
for inp in tqdm(files):
inp = str(inp)
mel = torch.tensor(numpy.load(inp)).to('cuda')
vocoder = Vocoder()
wav = vocoder.transform_mel_to_audio(mel)
wavfile.write(f'{inp}.wav', 22050, wav[0].cpu().numpy())

View File

@@ -282,7 +282,7 @@ class Trainer:
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_lrdvae_audio_clips.yml')
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_gpt_stop_libritts.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args()

View File

@@ -0,0 +1,249 @@
import torch
import numpy as np
from scipy import special
# Courtesy of https://www.kaggle.com/smallyellowduck/fast-audio-resampling-layer-in-pytorch
class AudioResampler(torch.nn.Module):
"""
Efficiently resample audio signals
This module is much faster than resampling with librosa because it exploits PyTorch's efficient conv1d operations.
It is also faster than the existing PyTorch resample function in
https://github.com/pytorch/audio/blob/b6a61c3f7d0267c77f8626167cc1eda0335f2753/torchaudio/compliance/kaldi.py#L892
Based on
https://github.com/danpovey/filtering/blob/master/lilfilter/resampler.py
with improvements to include additional filter types and input parameters that align with the librosa api
"""
def __init__(self,
input_sr, output_sr, dtype,
num_zeros=64, cutoff_ratio=0.95, filter='kaiser', beta=14.0):
super().__init__() # init the base class
"""
This creates an object that can apply a symmetric FIR filter
based on torch.nn.functional.conv1d.
Args:
input_sr: The input sampling rate, AS AN INTEGER. It does not have to be
the real sampling rate, but it should have the correct ratio with
output_sr.
output_sr: The output sampling rate, AS AN INTEGER.
It is the ratio with the input sampling rate that is
important here.
dtype: The torch dtype to use for computations (it would be preferable to
set things up so passing the dtype isn't necessary)
num_zeros: The number of zeros per side in the (sinc*hanning-window)
filter function. More is more accurate, but 64 is already
quite a lot. The kernel size is 2*num_zeros + 1.
cutoff_ratio: The filter rolloff point as a fraction of the
Nyquist frequency.
filter: one of ['kaiser', 'kaiser_best', 'kaiser_fast', 'hann']
beta: parameter for 'kaiser' filter
You can think of this algorithm as dividing up the signals
(input,output) into blocks where there are `input_sr` input
samples and `output_sr` output samples. Then we treat it
using convolutional code, imagining there are `input_sr`
input channels and `output_sr` output channels per time step.
"""
assert isinstance(input_sr, int) and isinstance(output_sr, int)
if input_sr == output_sr:
self.resample_type = 'trivial'
return
def gcd(a, b):
""" Return the greatest common divisor of a and b"""
assert isinstance(a, int) and isinstance(b, int)
if b == 0:
return a
else:
return gcd(b, a % b)
d = gcd(input_sr, output_sr)
input_sr, output_sr = input_sr // d, output_sr // d
assert dtype in [torch.float32, torch.float64]
assert num_zeros > 3 # a reasonable bare minimum
np_dtype = np.float32 if dtype == torch.float32 else np.float64
assert filter in ['hann', 'kaiser', 'kaiser_best', 'kaiser_fast']
if filter == 'kaiser_best':
num_zeros = 64
beta = 14.769656459379492
cutoff_ratio = 0.9475937167399596
filter = 'kaiser'
elif filter == 'kaiser_fast':
num_zeros = 16
beta = 8.555504641634386
cutoff_ratio = 0.85
filter = 'kaiser'
# Define one 'block' of samples `input_sr` input samples
# and `output_sr` output samples. We can divide up
# the samples into these blocks and have the blocks be
# in correspondence.
# The sinc function will have, on average, `zeros_per_block`
# zeros per block.
zeros_per_block = min(input_sr, output_sr) * cutoff_ratio
# The convolutional kernel size will be n = (blocks_per_side*2 + 1),
# i.e. we add that many blocks on each side of the central block. The
# window radius (defined as distance from center to edge)
# is `blocks_per_side` blocks. This ensures that each sample in the
# central block can "see" all the samples in its window.
#
# Assuming the following division is not exact, adding 1
# will have the same effect as rounding up.
# blocks_per_side = 1 + int(num_zeros / zeros_per_block)
blocks_per_side = int(np.ceil(num_zeros / zeros_per_block))
kernel_width = 2 * blocks_per_side + 1
# We want the weights as used by torch's conv1d code; format is
# (out_channels, in_channels, kernel_width)
# https://pytorch.org/docs/stable/nn.functional.html
# (The actual weights of that shape are computed below from the windowed-sinc filter.)
# Computations involving time will be in units of 1 block. Actually this
# is the same as the `canonical` time axis since each block has input_sr
# input samples, so it would be one of whatever time unit we are using
window_radius_in_blocks = blocks_per_side
# The `times` below will end up being the args to the sinc function.
# For the shapes of the things below, look at the args to `view`. The terms
# below will get expanded to shape (output_sr, input_sr, kernel_width) through
# broadcasting
# We want it so that, assuming input_sr == output_sr, along the diagonal of
# the central block we have t == 0.
# The signs of the output_sr and input_sr terms need to be opposite. The
# sign that the kernel_width term needs to be will depend on whether it's
# convolution or correlation, and the logic is tricky.. I will just find
# which sign works.
times = (
np.arange(output_sr, dtype=np_dtype).reshape((output_sr, 1, 1)) / output_sr -
np.arange(input_sr, dtype=np_dtype).reshape((1, input_sr, 1)) / input_sr -
(np.arange(kernel_width, dtype=np_dtype).reshape((1, 1, kernel_width)) - blocks_per_side))
def hann_window(a):
"""
hann_window returns the Hann window on [-1,1], which is zero
if a < -1 or a > 1, and otherwise 0.5 + 0.5 cos(a*pi).
This is applied elementwise to a, which should be a NumPy array.
The heaviside function returns (a > 0 ? 1 : 0).
"""
return np.heaviside(1 - np.abs(a), 0.0) * (0.5 + 0.5 * np.cos(a * np.pi))
def kaiser_window(a, beta):
w = special.i0(beta * np.sqrt(np.clip(1 - ((a - 0.0) / 1.0) ** 2.0, 0.0, 1.0))) / special.i0(beta)
return np.heaviside(1 - np.abs(a), 0.0) * w
# The weights below are a sinc function times a Hann-window function.
#
# Multiplication by zeros_per_block normalizes the sinc function
# (to compensate for scaling on the x-axis), so that the integral is 1.
#
# Division by input_sr normalizes the input function. Think of the input
# as a stream of dirac deltas passing through a low pass filter:
# in order to have the same magnitude as the original input function,
# we need to divide by the number of those deltas per unit time.
if filter == 'hann':
weights = (np.sinc(times * zeros_per_block)
* hann_window(times / window_radius_in_blocks)
* zeros_per_block / input_sr)
else:
weights = (np.sinc(times * zeros_per_block)
* kaiser_window(times / window_radius_in_blocks, beta)
* zeros_per_block / input_sr)
self.input_sr = input_sr
self.output_sr = output_sr
# weights has dim (output_sr, input_sr, kernel_width).
# If output_sr == 1, we can fold the input_sr into the
# kernel_width (i.e. have just 1 input channel); this will make the
# convolution faster and avoid unnecessary reshaping.
assert weights.shape == (output_sr, input_sr, kernel_width)
if output_sr == 1:
self.resample_type = 'integer_downsample'
self.padding = input_sr * blocks_per_side
weights = torch.tensor(weights, dtype=dtype, requires_grad=False)
self.weights = weights.transpose(1, 2).contiguous().view(1, 1, input_sr * kernel_width)
elif input_sr == 1:
# In this case we'll be doing conv_transpose, so we want the same weights that
# we would have if we were *downsampling* by this factor-- i.e. as if input_sr,
# output_sr had been swapped.
self.resample_type = 'integer_upsample'
self.padding = output_sr * blocks_per_side
weights = torch.tensor(weights, dtype=dtype, requires_grad=False)
self.weights = weights.flip(2).transpose(0, 2).contiguous().view(1, 1, output_sr * kernel_width)
else:
self.resample_type = 'general'
self.reshaped = False
self.padding = blocks_per_side
self.weights = torch.tensor(weights, dtype=dtype, requires_grad=False)
self.weights = torch.nn.Parameter(self.weights, requires_grad=False)
@torch.no_grad()
def forward(self, data):
"""
Resample the data
Args:
data: a torch.Tensor with the same dtype as was passed to the
constructor.
There must be 2 axes, interpreted as (minibatch_size, sequence_length)...
the minibatch_size may in practice be the number of channels.
Return: Returns a torch.Tensor with the same dtype as the input, and
dimension (minibatch_size, (sequence_length//input_sr)*output_sr),
where input_sr and output_sr are the corresponding constructor args,
modified to remove any common factors.
"""
if self.resample_type == 'trivial':
return data
elif self.resample_type == 'integer_downsample':
(minibatch_size, seq_len) = data.shape
# will be shape (minibatch_size, in_channels, seq_len) with in_channels == 1
data = data.unsqueeze(1)
data = torch.nn.functional.conv1d(data,
self.weights,
stride=self.input_sr,
padding=self.padding)
# shape will be (minibatch_size, out_channels = 1, seq_len);
# return as (minibatch_size, seq_len)
return data.squeeze(1)
elif self.resample_type == 'integer_upsample':
data = data.unsqueeze(1)
data = torch.nn.functional.conv_transpose1d(data,
self.weights,
stride=self.output_sr,
padding=self.padding)
return data.squeeze(1)
else:
assert self.resample_type == 'general'
(minibatch_size, seq_len) = data.shape
num_blocks = seq_len // self.input_sr
if num_blocks == 0:
# TODO: pad with zeros.
raise RuntimeError("Signal is too short to resample")
# data = data[:, 0:(num_blocks*self.input_sr)] # Truncate input
data = data[:, 0:(num_blocks * self.input_sr)].view(minibatch_size, num_blocks, self.input_sr)
# Torch's conv1d expects input data with shape (minibatch, in_channels, time_steps), so transpose
data = data.transpose(1, 2)
data = torch.nn.functional.conv1d(data, self.weights,
padding=self.padding)
assert data.shape == (minibatch_size, self.output_sr, num_blocks)
return data.transpose(1, 2).contiguous().view(minibatch_size, num_blocks * self.output_sr)
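
A minimal usage sketch for the resampler (not part of the commit; the 24 kHz -> 22.05 kHz figures are illustrative):

import math
import torch
from utils.audio_resampler import AudioResampler

# One second of a 440 Hz sine at 24 kHz, shaped (minibatch_size, sequence_length).
t = torch.arange(24000, dtype=torch.float32) / 24000.0
signal = torch.sin(2 * math.pi * 440.0 * t).unsqueeze(0)
resampler = AudioResampler(24000, 22050, dtype=torch.float32)
out = resampler(signal)
print(out.shape)  # (1, 22050): 24000//160 = 150 blocks, 147 output samples per block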