From d9936df3634b058c0a4625c45d6cb63a5fb88b20 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Wed, 4 Aug 2021 00:44:04 -0600
Subject: [PATCH] Add gpt_tts dataset and implement inference

- Adds a script which preprocesses quantized mels given a DVAE
- Adds a dataset which can consume preprocessed qmels
- Reworks GPT TTS to consume the outputs of that dataset (removes logic to
  add padding and start/end tokens)
- Adds inference to gpt_tts
---
 codes/data/__init__.py                        |   4 +
 codes/data/audio/gpt_tts_dataset.py           | 104 ++++++++++++++++++
 codes/data/audio/nv_tacotron_dataset.py       |   7 +-
 codes/models/gpt_voice/gpt_tts.py             |  82 ++++----------
 codes/models/vqvae/vqvae.py                   |   4 +-
 .../scripts/audio/generate_quantized_mels.py  |  68 ++++++++++++
 6 files changed, 205 insertions(+), 64 deletions(-)
 create mode 100644 codes/data/audio/gpt_tts_dataset.py
 create mode 100644 codes/scripts/audio/generate_quantized_mels.py

diff --git a/codes/data/__init__.py b/codes/data/__init__.py
index 0cd5f8af..4130c525 100644
--- a/codes/data/__init__.py
+++ b/codes/data/__init__.py
@@ -70,6 +70,10 @@ def create_dataset(dataset_opt, return_collate=False):
         default_params.update(dataset_opt)
         dataset_opt = munchify(default_params)
         collate = C(dataset_opt.n_frames_per_step)
+    elif mode == 'gpt_tts':
+        from data.audio.gpt_tts_dataset import GptTtsDataset as D
+        from data.audio.gpt_tts_dataset import GptTtsCollater as C
+        collate = C(dataset_opt)
     else:
         raise NotImplementedError('Dataset [{:s}] is not recognized.'.format(mode))
     dataset = D(dataset_opt)
diff --git a/codes/data/audio/gpt_tts_dataset.py b/codes/data/audio/gpt_tts_dataset.py
new file mode 100644
index 00000000..d871995b
--- /dev/null
+++ b/codes/data/audio/gpt_tts_dataset.py
@@ -0,0 +1,104 @@
+import os
+import random
+import numpy as np
+import torch
+import torch.utils.data
+from torch import LongTensor
+from tqdm import tqdm
+
+import models.tacotron2.layers as layers
+from models.tacotron2.taco_utils import load_wav_to_torch, load_filepaths_and_text
+
+from models.tacotron2.text import text_to_sequence
+from utils.util import opt_get
+from models.tacotron2.text import symbols
+import torch.nn.functional as F
+
+
+class GptTtsDataset(torch.utils.data.Dataset):
+    NUMBER_SYMBOLS = len(symbols)+3
+    TEXT_START_TOKEN = LongTensor([NUMBER_SYMBOLS-3])
+    TEXT_STOP_TOKEN = LongTensor([NUMBER_SYMBOLS-2])
+
+    def __init__(self, opt):
+        self.path = os.path.dirname(opt['path'])
+        self.audiopaths_and_text = load_filepaths_and_text(opt['path'])
+        self.text_cleaners = ['english_cleaners']
+
+        self.MEL_DICTIONARY_SIZE = opt['mel_vocab_size']+3
+        self.MEL_START_TOKEN = LongTensor([self.MEL_DICTIONARY_SIZE-3])
+        self.MEL_STOP_TOKEN = LongTensor([self.MEL_DICTIONARY_SIZE-2])
+
+    def __getitem__(self, index):
+        # Fetch text and add start/stop tokens.
+        audiopath_and_text = self.audiopaths_and_text[index]
+        audiopath, text = audiopath_and_text[0], audiopath_and_text[1]
+        # Use int64 so the sequence can be concatenated with the LongTensor start/stop
+        # tokens below and consumed directly by nn.Embedding / cross_entropy downstream.
+        text = LongTensor(text_to_sequence(text, self.text_cleaners))
+        text = torch.cat([self.TEXT_START_TOKEN, text, self.TEXT_STOP_TOKEN], dim=0)
+
+        # Fetch quantized MELs
+        quant_path = audiopath.replace('wavs/', 'quantized_mels/') + '.pth'
+        filename = os.path.join(self.path, quant_path)
+        qmel = torch.load(filename)
+        qmel = torch.cat([self.MEL_START_TOKEN, qmel, self.MEL_STOP_TOKEN])
+
+        return text, qmel, audiopath
+
+    def __len__(self):
+        return len(self.audiopaths_and_text)
+
+
+class GptTtsCollater():
+    NUMBER_SYMBOLS = len(symbols)+3
+    TEXT_PAD_TOKEN = NUMBER_SYMBOLS-1
+
+    def __init__(self, opt):
+        self.MEL_DICTIONARY_SIZE = opt['mel_vocab_size']+3
+        self.MEL_PAD_TOKEN = self.MEL_DICTIONARY_SIZE-1
+
+    def __call__(self, batch):
+        text_lens = [len(x[0]) for x in batch]
+        max_text_len = max(text_lens)
+        mel_lens = [len(x[1]) for x in batch]
+        max_mel_len = max(mel_lens)
+        texts = []
+        qmels = []
+        for b in batch:
+            text, qmel, _ = b
+            texts.append(F.pad(text, (0, max_text_len-len(text)), value=self.TEXT_PAD_TOKEN))
+            qmels.append(F.pad(qmel, (0, max_mel_len-len(qmel)), value=self.MEL_PAD_TOKEN))
+
+        filenames = [j[2] for j in batch]
+
+        return {
+            'padded_text': torch.stack(texts),
+            'input_lengths': LongTensor(text_lens),
+            'padded_qmel': torch.stack(qmels),
+            'output_lengths': LongTensor(mel_lens),
+            'filenames': filenames
+        }
+
+
+if __name__ == '__main__':
+    params = {
+        'mode': 'gpt_tts',
+        'path': 'E:\\audio\\LJSpeech-1.1\\ljs_audio_text_train_filelist.txt',
+        'phase': 'train',
+        'n_workers': 0,
+        'batch_size': 16,
+        'mel_vocab_size': 512,
+    }
+    from data import create_dataset, create_dataloader
+
+    ds, c = create_dataset(params, return_collate=True)
+    dl = create_dataloader(ds, params, collate_fn=c)
+    # Quick sanity check: report the longest padded text and mel sequences in the dataset.
+    max_text = 0
+    max_mel = 0
+    for b in tqdm(dl):
+        max_mel = max(max_mel, b['padded_qmel'].shape[1])
+        max_text = max(max_text, b['padded_text'].shape[1])
+    print(max_text, max_mel)
diff --git a/codes/data/audio/nv_tacotron_dataset.py b/codes/data/audio/nv_tacotron_dataset.py
index 4950839c..78438200 100644
--- a/codes/data/audio/nv_tacotron_dataset.py
+++ b/codes/data/audio/nv_tacotron_dataset.py
@@ -41,7 +41,7 @@ class TextMelLoader(torch.utils.data.Dataset):
         audiopath = os.path.join(self.path, audiopath)
         text = self.get_text(text)
         mel = self.get_mel(audiopath)
-        return (text, mel)
+        return (text, mel, audiopath_and_text[0])
 
     def get_mel(self, filename):
         if not self.load_mel_from_disk:
@@ -88,7 +88,7 @@ class TextMelCollate():
         """Collate's training batch from normalized text and mel-spectrogram
         PARAMS
         ------
-        batch: [text_normalized, mel_normalized]
+        batch: [text_normalized, mel_normalized, filename]
         """
         # Right zero-pad all one-hot text sequences to max input length
         input_lengths, ids_sorted_decreasing = torch.sort(
@@ -121,12 +121,15 @@ class TextMelCollate():
             gate_padded[i, mel.size(1)-1:] = 1
             output_lengths[i] = mel.size(1)
 
+        filenames = [j[2] for j in batch]
+
         return {
             'padded_text': text_padded,
             'input_lengths': input_lengths,
             'padded_mel': mel_padded,
             'padded_gate': gate_padded,
             'output_lengths': output_lengths,
+            'filenames': filenames
         }
 
 
diff --git a/codes/models/gpt_voice/gpt_tts.py b/codes/models/gpt_voice/gpt_tts.py
index ed06fd24..eb7ad706 100644
--- a/codes/models/gpt_voice/gpt_tts.py
+++ b/codes/models/gpt_voice/gpt_tts.py
@@ -2,6 +2,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from munch import munchify
+from torch import LongTensor
 from tqdm import tqdm
 
 from models.arch_util import ConvGnSilu
@@ -45,26 +46,6 @@ class GptTts(nn.Module):
         self.mel_head = nn.Linear(model_dim, self.MEL_DICTIONARY_SIZE)
 
     def forward(self, text_inputs, text_lengths, mel_targets, output_lengths):
-        output_lengths = output_lengths * 3 // 8  # The data we are dealing with has been compressed by the vqvae.
-        # Add the stop tokens to the end of the texts and mels. Theoretically this would be better done at the dataloader level.
-        batch_range = torch.arange(0, text_inputs.shape[0])
-        text_inputs = F.pad(text_inputs, (0,1))
-        text_inputs.index_put_((batch_range, text_lengths), torch.tensor([self.TEXT_STOP_TOKEN], dtype=torch.long, device=text_inputs.device))
-        text_lengths = text_lengths + 1
-        mel_targets = F.pad(mel_targets, (0,1))
-        mel_targets.index_put_((batch_range, output_lengths), torch.tensor([self.MEL_STOP_TOKEN], dtype=torch.long, device=text_inputs.device))
-        output_lengths = output_lengths + 1
-        # Add the start tokens to the beginnings of the texts and mels.
-        text_inputs = F.pad(text_inputs, (1,0), value=self.TEXT_START_TOKEN)
-        text_lengths = text_lengths + 1
-        mel_targets = F.pad(mel_targets, (1,0), value=self.MEL_START_TOKEN)
-        output_lengths = output_lengths + 1
-        # Add padding as well. This also should realistically be done at the dataloader level.
-        text_pad_mask = ~get_mask_from_lengths(text_lengths, text_inputs.shape[1])
-        text_inputs.data.masked_fill_(text_pad_mask, self.TEXT_PAD_TOKEN)
-        mel_pad_mask = ~get_mask_from_lengths(output_lengths, mel_targets.shape[1])
-        mel_targets.data.masked_fill_(mel_pad_mask, self.MEL_PAD_TOKEN)
-
         text_emb = self.text_embedding(text_inputs)
         text_emb = text_emb + self.text_pos_embedding(torch.arange(text_inputs.shape[1], device=text_inputs.device))
         mel_emb = self.mel_embedding(mel_targets)
@@ -81,62 +62,43 @@ class GptTts(nn.Module):
         # Compute loss
         loss_text = F.cross_entropy(text_logits.permute(0,2,1)[:,:,1:], text_inputs[:,1:], reduction='none')
         loss_mel = F.cross_entropy(mel_logits.permute(0,2,1)[:,:,1:], mel_targets[:,1:], reduction='none')
+        # Apply a reduction factor across MEL_PAD and TEXT_PAD tokens.
         pad_loss_reduction_factor = .01
+        text_pad_mask = ~get_mask_from_lengths(text_lengths, text_inputs.shape[1])
+        mel_pad_mask = ~get_mask_from_lengths(output_lengths, mel_targets.shape[1])
         loss_text = loss_text * torch.ones_like(loss_text).masked_fill_(text_pad_mask[:,1:], pad_loss_reduction_factor)
         loss_mel = loss_mel * torch.ones_like(loss_mel).masked_fill_(mel_pad_mask[:,1:], pad_loss_reduction_factor)
 
         # Fix up mel_logits so it can go into a VAE decoder as well.
         mel_codes = torch.argmax(F.softmax(mel_logits, dim=-1), dim=-1)
-        mel_codes = mel_codes[:,1:]
-        mel_codes = mel_codes * torch.ones_like(mel_codes).masked_fill_(mel_pad_mask[:,1:], 0)
-        mel_codes = mel_codes[:,:-1]
+        mel_codes = mel_codes[:,1:-1]  # Strip off first and last tokens (START+STOP were added by the dataloader)
+        mel_codes = mel_codes * torch.ones_like(mel_codes).masked_fill_(mel_pad_mask[:,1:-1], 0)
         extra_mask = mel_codes < self.MEL_DICTIONARY_SIZE-3  # The VAE doesn't know about START/STOP/PAD
         mel_codes = mel_codes * extra_mask
 
         return loss_text.mean(), loss_mel.mean(), mel_codes
 
-    def inference(self, text_inputs, mel_guide):
-        MEL_HEAD_EXPANSION = 2
-        GATE_THRESHOLD = .95
-
+    def inference(self, text_inputs):
         text_emb = self.text_embedding(text_inputs)
-        text_emb = self.text_preprocess_xformer(text_emb, text_emb.shape[1])
-        text_emb = text_emb + self.text_tags
-        b,s,c = text_emb.shape
-        emb = torch.cat([text_emb,
-                         self.separator.repeat(text_emb.shape[0],1,1),], dim=1)
-                         #self.test_guide(mel_guide)], dim=1)
-        completed = torch.zeros((b,), device=text_inputs.device, dtype=torch.bool)
-        output = None
-        for i in tqdm(range(self.max_mel_frames)):
-            enc = self.gpt(emb, text_emb.shape[1])
-            inferred = enc[:,s:,:].permute(0,2,1)
-            # Create output frames.
-            inferred_mel_frame = self.mel_head(inferred)[:,:,-MEL_HEAD_EXPANSION:]
-            inferred_mel_frame = inferred_mel_frame * (~completed).float().view(b,1,1)
-            if output is None:
-                output = inferred_mel_frame
-            else:
-                output = torch.cat([output, inferred_mel_frame], dim=2)
+        text_emb = text_emb + self.text_pos_embedding(torch.arange(text_inputs.shape[1], device=text_inputs.device))
 
-            # Test termination condition
-            gate = F.sigmoid(self.gate_head(inferred)).max(dim=-1).values  # TODO: accept single-frame terminations.
-            completed = completed.logical_or((gate > GATE_THRESHOLD).squeeze(1))  # This comprises a latch - but that may not be wise.
-            if torch.all(completed):
-                break
+        # Greedy decoding. The trailing 0 is a placeholder that is overwritten with the
+        # prediction for that position on each pass, so the stop check looks at mel_seq[-2],
+        # the most recently predicted code.
+        mel_seq = [self.MEL_START_TOKEN, 0]
+        while mel_seq[-2] != self.MEL_STOP_TOKEN and len(mel_seq) < self.max_mel_frames:
+            mel_emb = self.mel_embedding(LongTensor(mel_seq).unsqueeze(0).to(text_inputs.device))
+            mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(len(mel_seq), device=text_inputs.device))
+            emb = torch.cat([text_emb, mel_emb], dim=1)
+            enc = self.gpt(emb)
+            mel_logits = self.final_norm(enc[:, text_emb.shape[1]:])
+            mel_logits = self.mel_head(mel_logits)
+            mel_codes = torch.argmax(F.softmax(mel_logits, dim=-1), dim=-1)
+            mel_seq[-1] = mel_codes[0, -1].item()  # Assumes a single input sequence (batch size 1).
+            mel_seq.append(0)
 
-            # Apply inferred mel_frames to emb for next pass.
-            mel_emb = self.mel_encoder(output).permute(0,2,1)
-            mel_emb = mel_emb + self.audio_tags
-            emb = torch.cat([text_emb,
-                             self.separator.repeat(text_emb.shape[0],1,1),
-                             mel_emb], dim=1)
-            if i == self.max_mel_frames//2:
-                print("Warning! Inference hit mel frame cap without encountering a stop token.")
-                break
+        if len(mel_seq) >= self.max_mel_frames:
+            print("Warning! Encountered frame limit before a stop token. Output is likely wrong.")
 
-        return output
+        return mel_seq[:-1]
 
 
 @register_model
diff --git a/codes/models/vqvae/vqvae.py b/codes/models/vqvae/vqvae.py
index 75f91b8b..5b7cc09c 100644
--- a/codes/models/vqvae/vqvae.py
+++ b/codes/models/vqvae/vqvae.py
@@ -223,7 +223,7 @@ class VQVAE(nn.Module):
         quant_t = self.quantize_conv_t(enc_t).permute((0,2,3,1) if len(input.shape) == 4 else (0,2,1))
         quant_t, diff_t, id_t = self.quantize_t(quant_t)
-        quant_t = quant_t.permute((0,3,1,2) if len(input) == 4 else (0,2,1))
+        quant_t = quant_t.permute((0,3,1,2) if len(input.shape) == 4 else (0,2,1))
         diff_t = diff_t.unsqueeze(0)
 
         dec_t = checkpoint(self.dec_t, quant_t)
@@ -231,7 +231,7 @@
         quant_b = checkpoint(self.quantize_conv_b, enc_b).permute((0,2,3,1) if len(input.shape) == 4 else (0,2,1))
         quant_b, diff_b, id_b = self.quantize_b(quant_b)
-        quant_b = quant_b.permute((0,3,1,2) if len(input) == 4 else (0,2,1))
+        quant_b = quant_b.permute((0,3,1,2) if len(input.shape) == 4 else (0,2,1))
         diff_b = diff_b.unsqueeze(0)
 
         return quant_t, quant_b, diff_t + diff_b, id_t, id_b
diff --git a/codes/scripts/audio/generate_quantized_mels.py b/codes/scripts/audio/generate_quantized_mels.py
new file mode 100644
index 00000000..6adf7ea1
--- /dev/null
+++ b/codes/scripts/audio/generate_quantized_mels.py
@@ -0,0 +1,68 @@
+import os
+import os.path as osp
+import logging
+import random
+import argparse
+
+import torchvision
+
+import utils
+import utils.options as option
+import utils.util as util
+from models.waveglow.denoiser import Denoiser
+from trainer.ExtensibleTrainer import ExtensibleTrainer
+from data import create_dataset, create_dataloader
+from tqdm import tqdm
+import torch
+import numpy as np
+from scipy.io import wavfile
+
+if __name__ == "__main__":
+    # Set seeds
+    torch.manual_seed(5555)
+    random.seed(5555)
+    np.random.seed(5555)
+
+    #### options
+    torch.backends.cudnn.benchmark = True
+    want_metrics = False
+    parser = argparse.ArgumentParser()
+    parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/generate_quantized_mels.yml')
+    opt = option.parse(parser.parse_args().opt, is_train=False)
+    opt = option.dict_to_nonedict(opt)
+    utils.util.loaded_options = opt
+
+    util.mkdirs(
+        (path for key, path in opt['path'].items()
+         if not key == 'experiments_root' and 'pretrain_model' not in key and 'resume' not in key))
+    util.setup_logger('base', opt['path']['log'], 'test_' + opt['name'], level=logging.INFO,
+                      screen=True, tofile=True)
+    logger = logging.getLogger('base')
+    logger.info(option.dict2str(opt))
+
+    test_loaders = []
+    for phase, dataset_opt in sorted(opt['datasets'].items()):
+        test_set, collate_fn = create_dataset(dataset_opt, return_collate=True)
+        test_loader = create_dataloader(test_set, dataset_opt, collate_fn=collate_fn)
+        logger.info('Number of test texts in [{:s}]: {:d}'.format(dataset_opt['name'], len(test_set)))
+        test_loaders.append(test_loader)
+
+    model = ExtensibleTrainer(opt)
+
+    outpath = opt['path']['results_root']
+    os.makedirs(os.path.join(outpath, 'quantized_mels'), exist_ok=True)
+    for test_loader in test_loaders:
+        dataset_dir = opt['path']['results_root']
+        util.mkdir(dataset_dir)
+
+        tq = tqdm(test_loader)
+        for data in tq:
+            with torch.no_grad():
+                model.feed_data(data, 0)
+                model.test()
+
+            wavfiles = data['filenames']
+            quantized = model.eval_state[opt['eval']['quantized_mels']][0]
+            for i, wavfile in enumerate(wavfiles):
+                qmelfile = wavfile.replace('wavs/', 'quantized_mels/') + '.pth'
+                torch.save(quantized[i], os.path.join(outpath, qmelfile))
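
Reviewer note, not part of the patch: a minimal sketch of how the new greedy inference path might be driven end to end. It assumes a trained GptTts instance has already been constructed and loaded on the target device (its constructor is not shown in this diff); the helper name synthesize_mel_codes, the 'english_cleaners' setting, and the default 'cuda' device are illustrative only. The text START/STOP values mirror GptTtsDataset, and the returned codes are trimmed of START/STOP the same way forward() trims them before they would go to a VQVAE decoder.

    import torch
    from models.tacotron2.text import text_to_sequence, symbols

    def synthesize_mel_codes(model, text, device='cuda'):
        """Greedily generate quantized mel codes for one sentence via GptTts.inference()."""
        # Wrap the token sequence with the same START/STOP symbols GptTtsDataset adds.
        n_symbols = len(symbols) + 3
        seq = text_to_sequence(text, ['english_cleaners'])
        tokens = torch.LongTensor([n_symbols - 3] + seq + [n_symbols - 2]).unsqueeze(0).to(device)
        with torch.no_grad():
            codes = model.inference(tokens)   # Python list: [MEL_START, code, ..., MEL_STOP]
        # Strip START/STOP; assumes a stop token was emitted before the frame cap.
        return torch.LongTensor(codes[1:-1])

Usage would be something along the lines of codes = synthesize_mel_codes(model, 'Hello world.'), after which the codes could be decoded back into a spectrogram by the DVAE that produced the quantized training targets.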