Add gpt_tts dataset and implement inference

- Adds a script which preprocesses quantized mels given a DVAE
- Adds a dataset which can consume preprocessed qmels
- Reworks GPT TTS to consume the outputs of that dataset (removes logic to add padding and start/end tokens)
- Adds inference to gpt_tts
James Betker 2021-08-04 00:44:04 -06:00
parent 4c98b9703f
commit d9936df363
6 changed files with 205 additions and 64 deletions
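Note: the contract between the new dataset and the reworked model is a shared token layout. Each vocabulary is extended by three reserved ids: START = size-3, STOP = size-2, PAD = size-1. A minimal sketch of the mel side, assuming the 512-code mel vocabulary used in the dataset's own test harness (the text side is analogous):

```python
MEL_VOCAB_SIZE = 512                       # the 'mel_vocab_size' option
MEL_DICTIONARY_SIZE = MEL_VOCAB_SIZE + 3   # real codes 0..511 plus 3 specials
MEL_START_TOKEN = MEL_DICTIONARY_SIZE - 3  # 512
MEL_STOP_TOKEN = MEL_DICTIONARY_SIZE - 2   # 513
MEL_PAD_TOKEN = MEL_DICTIONARY_SIZE - 1    # 514

# A framed, collated qmel row then looks like:
# [512, c0, c1, ..., cN, 513, 514, 514, ...]   (START, codes, STOP, right-PAD)
```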

View File

@@ -70,6 +70,10 @@ def create_dataset(dataset_opt, return_collate=False):
        default_params.update(dataset_opt)
        dataset_opt = munchify(default_params)
        collate = C(dataset_opt.n_frames_per_step)
+   elif mode == 'gpt_tts':
+       from data.audio.gpt_tts_dataset import GptTtsDataset as D
+       from data.audio.gpt_tts_dataset import GptTtsCollater as C
+       collate = C(dataset_opt)
    else:
        raise NotImplementedError('Dataset [{:s}] is not recognized.'.format(mode))
    dataset = D(dataset_opt)
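With this registration, any dataset config whose `mode` is `gpt_tts` constructs the dataset/collater pair from `data/audio/gpt_tts_dataset.py`, shown next and exercised by that module's `__main__` test.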

View File

@@ -0,0 +1,104 @@
import os
import random

import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.data
from torch import LongTensor
from tqdm import tqdm

import models.tacotron2.layers as layers
from models.tacotron2.taco_utils import load_wav_to_torch, load_filepaths_and_text
from models.tacotron2.text import symbols, text_to_sequence
from utils.util import opt_get
class GptTtsDataset(torch.utils.data.Dataset):
    # The text vocabulary is the symbol set plus three reserved ids: START, STOP and (in the collater) PAD.
    NUMBER_SYMBOLS = len(symbols)+3
    TEXT_START_TOKEN = LongTensor([NUMBER_SYMBOLS-3])
    TEXT_STOP_TOKEN = LongTensor([NUMBER_SYMBOLS-2])

    def __init__(self, opt):
        self.path = os.path.dirname(opt['path'])
        self.audiopaths_and_text = load_filepaths_and_text(opt['path'])
        self.text_cleaners = ['english_cleaners']

        # The quantized-mel vocabulary is likewise extended by three reserved ids.
        self.MEL_DICTIONARY_SIZE = opt['mel_vocab_size']+3
        self.MEL_START_TOKEN = LongTensor([self.MEL_DICTIONARY_SIZE-3])
        self.MEL_STOP_TOKEN = LongTensor([self.MEL_DICTIONARY_SIZE-2])
    def __getitem__(self, index):
        # Fetch text and add start/stop tokens.
        audiopath_and_text = self.audiopaths_and_text[index]
        audiopath, text = audiopath_and_text[0], audiopath_and_text[1]
        text = torch.IntTensor(text_to_sequence(text, self.text_cleaners))
        text = torch.cat([self.TEXT_START_TOKEN, text, self.TEXT_STOP_TOKEN], dim=0)

        # Fetch quantized MELs
        quant_path = audiopath.replace('wavs/', 'quantized_mels/') + '.pth'
        filename = os.path.join(self.path, quant_path)
        qmel = torch.load(filename)
        qmel = torch.cat([self.MEL_START_TOKEN, qmel, self.MEL_STOP_TOKEN])

        return text, qmel, audiopath

    def __len__(self):
        return len(self.audiopaths_and_text)
class GptTtsCollater():
    NUMBER_SYMBOLS = len(symbols)+3
    TEXT_PAD_TOKEN = NUMBER_SYMBOLS-1

    def __init__(self, opt):
        self.MEL_DICTIONARY_SIZE = opt['mel_vocab_size']+3
        self.MEL_PAD_TOKEN = self.MEL_DICTIONARY_SIZE-1

    def __call__(self, batch):
        text_lens = [len(x[0]) for x in batch]
        max_text_len = max(text_lens)
        mel_lens = [len(x[1]) for x in batch]
        max_mel_len = max(mel_lens)
        texts = []
        qmels = []
        for b in batch:
            text, qmel, _ = b
            texts.append(F.pad(text, (0, max_text_len-len(text)), value=self.TEXT_PAD_TOKEN))
            qmels.append(F.pad(qmel, (0, max_mel_len-len(qmel)), value=self.MEL_PAD_TOKEN))
        filenames = [j[2] for j in batch]

        return {
            'padded_text': torch.stack(texts),
            'input_lengths': LongTensor(text_lens),
            'padded_qmel': torch.stack(qmels),
            'output_lengths': LongTensor(mel_lens),
            'filenames': filenames
        }
if __name__ == '__main__':
    params = {
        'mode': 'gpt_tts',
        'path': 'E:\\audio\\LJSpeech-1.1\\ljs_audio_text_train_filelist.txt',
        'phase': 'train',
        'n_workers': 0,
        'batch_size': 16,
        'mel_vocab_size': 512,
    }
    from data import create_dataset, create_dataloader

    ds, c = create_dataset(params, return_collate=True)
    dl = create_dataloader(ds, params, collate_fn=c)
    max_text = 0
    max_mel = 0
    for b in tqdm(dl):
        # Batches are (batch, seq_len), so the padded lengths live in dim 1.
        max_mel = max(max_mel, b['padded_qmel'].shape[1])
        max_text = max(max_text, b['padded_text'].shape[1])
    print(max_text, max_mel)
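Note that the collater pads on the right only: for a 1D tensor, the second argument to `F.pad` is a `(left, right)` pair. A tiny self-contained example of the call used above:

```python
# Right-padding a 1D LongTensor to a fixed length, as GptTtsCollater does.
import torch
import torch.nn.functional as F

seq, max_len, PAD = torch.tensor([5, 6]), 4, 99
padded = F.pad(seq, (0, max_len - len(seq)), value=PAD)
print(padded)  # tensor([ 5,  6, 99, 99])
```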

View File

@@ -41,7 +41,7 @@ class TextMelLoader(torch.utils.data.Dataset):
        audiopath = os.path.join(self.path, audiopath)
        text = self.get_text(text)
        mel = self.get_mel(audiopath)
-       return (text, mel)
+       return (text, mel, audiopath_and_text[0])

    def get_mel(self, filename):
        if not self.load_mel_from_disk:
@@ -88,7 +88,7 @@ class TextMelCollate():
        """Collates a training batch from normalized text and mel-spectrograms
        PARAMS
        ------
-       batch: [text_normalized, mel_normalized]
+       batch: [text_normalized, mel_normalized, filename]
        """
        # Right zero-pad all one-hot text sequences to max input length
        input_lengths, ids_sorted_decreasing = torch.sort(
@@ -121,12 +121,15 @@ class TextMelCollate():
            gate_padded[i, mel.size(1)-1:] = 1
            output_lengths[i] = mel.size(1)
+       filenames = [j[2] for j in batch]

        return {
            'padded_text': text_padded,
            'input_lengths': input_lengths,
            'padded_mel': mel_padded,
            'padded_gate': gate_padded,
            'output_lengths': output_lengths,
+           'filenames': filenames
        }

View File

@@ -2,6 +2,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from munch import munchify
+from torch import LongTensor
from tqdm import tqdm

from models.arch_util import ConvGnSilu
@@ -45,26 +46,6 @@ class GptTts(nn.Module):
        self.mel_head = nn.Linear(model_dim, self.MEL_DICTIONARY_SIZE)

    def forward(self, text_inputs, text_lengths, mel_targets, output_lengths):
+       output_lengths = output_lengths * 3 // 8  # The data we are dealing with has been compressed by the vqvae.
-       # Add the stop tokens to the end of the texts and mels. Theoretically this would be better done at the dataloader level.
-       batch_range = torch.arange(0, text_inputs.shape[0])
-       text_inputs = F.pad(text_inputs, (0,1))
-       text_inputs.index_put_((batch_range, text_lengths), torch.tensor([self.TEXT_STOP_TOKEN], dtype=torch.long, device=text_inputs.device))
-       text_lengths = text_lengths + 1
-       mel_targets = F.pad(mel_targets, (0,1))
-       mel_targets.index_put_((batch_range, output_lengths), torch.tensor([self.MEL_STOP_TOKEN], dtype=torch.long, device=text_inputs.device))
-       output_lengths = output_lengths + 1
-       # Add the start tokens to the beginnings of the texts and mels.
-       text_inputs = F.pad(text_inputs, (1,0), value=self.TEXT_START_TOKEN)
-       text_lengths = text_lengths + 1
-       mel_targets = F.pad(mel_targets, (1,0), value=self.MEL_START_TOKEN)
-       output_lengths = output_lengths + 1
-       # Add padding as well. This also should realistically be done at the dataloader level.
-       text_pad_mask = ~get_mask_from_lengths(text_lengths, text_inputs.shape[1])
-       text_inputs.data.masked_fill_(text_pad_mask, self.TEXT_PAD_TOKEN)
-       mel_pad_mask = ~get_mask_from_lengths(output_lengths, mel_targets.shape[1])
-       mel_targets.data.masked_fill_(mel_pad_mask, self.MEL_PAD_TOKEN)
        text_emb = self.text_embedding(text_inputs)
        text_emb = text_emb + self.text_pos_embedding(torch.arange(text_inputs.shape[1], device=text_inputs.device))
        mel_emb = self.mel_embedding(mel_targets)
@@ -81,62 +62,43 @@ class GptTts(nn.Module):
        # Compute loss
        loss_text = F.cross_entropy(text_logits.permute(0,2,1)[:,:,1:], text_inputs[:,1:], reduction='none')
        loss_mel = F.cross_entropy(mel_logits.permute(0,2,1)[:,:,1:], mel_targets[:,1:], reduction='none')

        # Apply a reduction factor across MEL_PAD and TEXT_PAD tokens.
        pad_loss_reduction_factor = .01
        text_pad_mask = ~get_mask_from_lengths(text_lengths, text_inputs.shape[1])
        mel_pad_mask = ~get_mask_from_lengths(output_lengths, mel_targets.shape[1])
        loss_text = loss_text * torch.ones_like(loss_text).masked_fill_(text_pad_mask[:,1:], pad_loss_reduction_factor)
        loss_mel = loss_mel * torch.ones_like(loss_mel).masked_fill_(mel_pad_mask[:,1:], pad_loss_reduction_factor)

        # Fix up mel_logits so it can go into a VAE decoder as well.
        mel_codes = torch.argmax(F.softmax(mel_logits, dim=-1), dim=-1)
-       mel_codes = mel_codes[:,1:]
-       mel_codes = mel_codes * torch.ones_like(mel_codes).masked_fill_(mel_pad_mask[:,1:], 0)
-       mel_codes = mel_codes[:,:-1]
+       mel_codes = mel_codes[:,1:-1]  # Strip off first and last tokens (START+STOP were added by the dataloader)
+       mel_codes = mel_codes * torch.ones_like(mel_codes).masked_fill_(mel_pad_mask[:,1:-1], 0)
+       extra_mask = mel_codes < self.MEL_DICTIONARY_SIZE-3  # The VAE doesn't know about START/STOP/PAD
+       mel_codes = mel_codes * extra_mask

        return loss_text.mean(), loss_mel.mean(), mel_codes
-   def inference(self, text_inputs, mel_guide):
-       MEL_HEAD_EXPANSION = 2
-       GATE_THRESHOLD = .95
+   def inference(self, text_inputs):
        text_emb = self.text_embedding(text_inputs)
-       text_emb = self.text_preprocess_xformer(text_emb, text_emb.shape[1])
-       text_emb = text_emb + self.text_tags
-       b,s,c = text_emb.shape
-       emb = torch.cat([text_emb,
-                        self.separator.repeat(text_emb.shape[0],1,1),], dim=1)
-                        #self.test_guide(mel_guide)], dim=1)
-       completed = torch.zeros((b,), device=text_inputs.device, dtype=torch.bool)
-       output = None
-       for i in tqdm(range(self.max_mel_frames)):
-           enc = self.gpt(emb, text_emb.shape[1])
-           inferred = enc[:,s:,:].permute(0,2,1)
-           # Create output frames.
-           inferred_mel_frame = self.mel_head(inferred)[:,:,-MEL_HEAD_EXPANSION:]
-           inferred_mel_frame = inferred_mel_frame * (~completed).float().view(b,1,1)
-           if output is None:
-               output = inferred_mel_frame
-           else:
-               output = torch.cat([output, inferred_mel_frame], dim=2)
+       text_emb = text_emb + self.text_pos_embedding(torch.arange(text_inputs.shape[1], device=text_inputs.device))
-           # Test termination condition
-           gate = F.sigmoid(self.gate_head(inferred)).max(dim=-1).values  # TODO: accept single-frame terminations.
-           completed = completed.logical_or((gate > GATE_THRESHOLD).squeeze(1))  # This comprises a latch - but that may not be wise.
-           if torch.all(completed):
-               break
+       mel_seq = [self.MEL_START_TOKEN, 0]
+       while mel_seq[-1] != self.MEL_STOP_TOKEN and len(mel_seq) < self.max_mel_frames:
+           mel_emb = self.mel_embedding(LongTensor(mel_seq, device=text_inputs.device))
+           mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_seq.shape[1], device=mel_seq.device))
+           emb = torch.cat([text_emb, mel_emb], dim=1)
+           enc = self.gpt(emb)
+           mel_logits = self.final_norm(enc[:, text_emb.shape[1]:])
+           mel_logits = self.mel_head(mel_logits)
+           mel_codes = torch.argmax(F.softmax(mel_logits, dim=-1), dim=-1)
+           mel_seq[-1] = mel_codes[-1]
+           mel_seq.append(0)
-           # Apply inferred mel_frames to emb for next pass.
-           mel_emb = self.mel_encoder(output).permute(0,2,1)
-           mel_emb = mel_emb + self.audio_tags
-           emb = torch.cat([text_emb,
-                            self.separator.repeat(text_emb.shape[0],1,1),
-                            mel_emb], dim=1)
-           if i == self.max_mel_frames//2:
-               print("Warning! Inference hit mel frame cap without encountering a stop token.")
-               break
+       if len(mel_seq) >= self.max_mel_frames:
+           print("Warning! Encountered frame limit before a stop token. Output is likely wrong.")
-       return output
+       return mel_seq[:-1]
@register_model
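The new inference path decodes mel codes autoregressively: start from MEL_START_TOKEN, rerun the GPT over the text plus the codes emitted so far, take the argmax code, and stop on MEL_STOP_TOKEN or the frame cap. As committed, the loop keeps `mel_seq` as a Python list, which `LongTensor(mel_seq, device=...)` and `mel_seq.shape[1]` will reject at runtime. Below is a hedged standalone sketch (not the committed code) of the same greedy loop with tensor state, assuming batch size 1, integer `MEL_START_TOKEN`/`MEL_STOP_TOKEN` attributes, and the modules named in the diff:

```python
import torch

def inference_sketch(model, text_inputs):
    # Embed the text once; it is reused on every decode step.
    text_emb = model.text_embedding(text_inputs)
    text_emb = text_emb + model.text_pos_embedding(
        torch.arange(text_inputs.shape[1], device=text_inputs.device))

    # Growing (1, t) LongTensor of mel codes, seeded with the START id.
    mel_seq = torch.full((1, 1), model.MEL_START_TOKEN, dtype=torch.long,
                         device=text_inputs.device)
    for _ in range(model.max_mel_frames):
        mel_emb = model.mel_embedding(mel_seq)
        mel_emb = mel_emb + model.mel_pos_embedding(
            torch.arange(mel_seq.shape[1], device=mel_seq.device))
        emb = torch.cat([text_emb, mel_emb], dim=1)
        enc = model.gpt(emb)
        mel_logits = model.mel_head(model.final_norm(enc[:, text_emb.shape[1]:]))
        next_code = mel_logits[:, -1].argmax(dim=-1, keepdim=True)  # greedy pick
        mel_seq = torch.cat([mel_seq, next_code], dim=1)
        if next_code.item() == model.MEL_STOP_TOKEN:
            break
    else:
        print("Warning! Encountered frame limit before a stop token.")
    return mel_seq[:, 1:]  # strip START; STOP (if reached) is the last entry
```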

View File

@@ -223,7 +223,7 @@ class VQVAE(nn.Module):
        quant_t = self.quantize_conv_t(enc_t).permute((0,2,3,1) if len(input.shape) == 4 else (0,2,1))
        quant_t, diff_t, id_t = self.quantize_t(quant_t)
-       quant_t = quant_t.permute((0,3,1,2) if len(input) == 4 else (0,2,1))
+       quant_t = quant_t.permute((0,3,1,2) if len(input.shape) == 4 else (0,2,1))
        diff_t = diff_t.unsqueeze(0)

        dec_t = checkpoint(self.dec_t, quant_t)
@@ -231,7 +231,7 @@ class VQVAE(nn.Module):
        quant_b = checkpoint(self.quantize_conv_b, enc_b).permute((0,2,3,1) if len(input.shape) == 4 else (0,2,1))
        quant_b, diff_b, id_b = self.quantize_b(quant_b)
-       quant_b = quant_b.permute((0,3,1,2) if len(input) == 4 else (0,2,1))
+       quant_b = quant_b.permute((0,3,1,2) if len(input.shape) == 4 else (0,2,1))
        diff_b = diff_b.unsqueeze(0)

        return quant_t, quant_b, diff_t + diff_b, id_t, id_b
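The two one-line fixes above correct a subtle bug: for a tensor, `len(input)` is the size of the first (batch) dimension, not the rank, so the 4D permute branch was taken whenever a batch happened to contain exactly 4 elements. A quick illustration:

```python
import torch

t = torch.zeros(4, 3, 2)          # a 3D tensor with batch size 4
print(len(t))                     # 4 -> size of dim 0, NOT the rank
print(len(t.shape), t.ndim)       # 3 3 -> the tensor's actual rank
```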

View File

@@ -0,0 +1,68 @@
import os
import os.path as osp
import logging
import random
import argparse
import torchvision
import utils
import utils.options as option
import utils.util as util
from models.waveglow.denoiser import Denoiser
from trainer.ExtensibleTrainer import ExtensibleTrainer
from data import create_dataset, create_dataloader
from tqdm import tqdm
import torch
import numpy as np
from scipy.io import wavfile
if __name__ == "__main__":
    # Set seeds
    torch.manual_seed(5555)
    random.seed(5555)
    np.random.seed(5555)

    #### options
    torch.backends.cudnn.benchmark = True
    want_metrics = False
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/generate_quantized_mels.yml')
    opt = option.parse(parser.parse_args().opt, is_train=False)
    opt = option.dict_to_nonedict(opt)
    utils.util.loaded_options = opt

    util.mkdirs(
        (path for key, path in opt['path'].items()
         if not key == 'experiments_root' and 'pretrain_model' not in key and 'resume' not in key))
    util.setup_logger('base', opt['path']['log'], 'test_' + opt['name'], level=logging.INFO,
                      screen=True, tofile=True)
    logger = logging.getLogger('base')
    logger.info(option.dict2str(opt))

    test_loaders = []
    for phase, dataset_opt in sorted(opt['datasets'].items()):
        test_set, collate_fn = create_dataset(dataset_opt, return_collate=True)
        test_loader = create_dataloader(test_set, dataset_opt, collate_fn=collate_fn)
        logger.info('Number of test texts in [{:s}]: {:d}'.format(dataset_opt['name'], len(test_set)))
        test_loaders.append(test_loader)

    model = ExtensibleTrainer(opt)
    outpath = opt['path']['results_root']
    os.makedirs(os.path.join(outpath, 'quantized_mels'), exist_ok=True)
    for test_loader in test_loaders:
        dataset_dir = opt['path']['results_root']
        util.mkdir(dataset_dir)

        tq = tqdm(test_loader)
        for data in tq:
            with torch.no_grad():
                model.feed_data(data, 0)
                model.test()

            wavfiles = data['filenames']
            quantized = model.eval_state[opt['eval']['quantized_mels']][0]
            for i, wavfile in enumerate(wavfiles):
                # Mirror the dataset's lookup: wavs/<name>.wav -> quantized_mels/<name>.wav.pth
                qmelfile = wavfile.replace('wavs/', 'quantized_mels/') + '.pth'
                torch.save(quantized[i], os.path.join(outpath, qmelfile))
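This closes the loop with the dataset above: the script writes one `.pth` LongTensor of codebook indices per wav under `quantized_mels/`, and `GptTtsDataset.__getitem__` later `torch.load`s it via the same `wavs/ -> quantized_mels/` path substitution. A spot-check sketch (the LJSpeech filename is hypothetical):

```python
import torch

# Hypothetical filename, for illustration only.
qmel = torch.load('results/quantized_mels/LJ001-0001.wav.pth')
print(qmel.dtype, qmel.shape)  # expected: a 1D LongTensor of mel codebook indices
```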