Add gpt_tts dataset and implement inference
- Adds a script which preprocesses quantized mels given a DVAE
- Adds a dataset which can consume preprocessed qmels
- Reworks GPT TTS to consume the outputs of that dataset (removes logic to add padding and start/end tokens)
- Adds inference to gpt_tts
parent: 4c98b9703f
commit: d9936df363
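For orientation, this commit establishes a three-step flow: generate_quantized_mels.py (added below) runs a trained DVAE over each wav and saves the resulting code tensor to disk, GptTtsDataset reloads those codes during training, and GptTts.inference later emits codes from the same vocabulary. The script and the dataset agree on one path convention; the sketch below restates it with an illustrative LJSpeech-style filename and a made-up helper name (qmel_path_for), assuming the usual wavs/ layout.

```python
import os
import torch

def qmel_path_for(dataset_root, audiopath):
    # Same mapping the script uses when saving and GptTtsDataset uses when loading:
    # 'wavs/LJ001-0001.wav' -> '<dataset_root>/quantized_mels/LJ001-0001.wav.pth'
    return os.path.join(dataset_root, audiopath.replace('wavs/', 'quantized_mels/') + '.pth')

# e.g. qmel = torch.load(qmel_path_for('E:\\audio\\LJSpeech-1.1', 'wavs/LJ001-0001.wav'))
```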
@@ -70,6 +70,10 @@ def create_dataset(dataset_opt, return_collate=False):
         default_params.update(dataset_opt)
         dataset_opt = munchify(default_params)
         collate = C(dataset_opt.n_frames_per_step)
+    elif mode == 'gpt_tts':
+        from data.audio.gpt_tts_dataset import GptTtsDataset as D
+        from data.audio.gpt_tts_dataset import GptTtsCollater as C
+        collate = C(dataset_opt)
     else:
         raise NotImplementedError('Dataset [{:s}] is not recognized.'.format(mode))
     dataset = D(dataset_opt)
codes/data/audio/gpt_tts_dataset.py (new file, 104 lines)
@@ -0,0 +1,104 @@
+import os
+import random
+import numpy as np
+import torch
+import torch.utils.data
+from torch import LongTensor
+from tqdm import tqdm
+
+import models.tacotron2.layers as layers
+from models.tacotron2.taco_utils import load_wav_to_torch, load_filepaths_and_text
+
+from models.tacotron2.text import text_to_sequence
+from utils.util import opt_get
+from models.tacotron2.text import symbols
+import torch.nn.functional as F
+
+
+class GptTtsDataset(torch.utils.data.Dataset):
+    NUMBER_SYMBOLS = len(symbols)+3
+    TEXT_START_TOKEN = LongTensor([NUMBER_SYMBOLS-3])
+    TEXT_STOP_TOKEN = LongTensor([NUMBER_SYMBOLS-2])
+
+    def __init__(self, opt):
+        self.path = os.path.dirname(opt['path'])
+        self.audiopaths_and_text = load_filepaths_and_text(opt['path'])
+        self.text_cleaners = ['english_cleaners']
+
+        self.MEL_DICTIONARY_SIZE = opt['mel_vocab_size']+3
+        self.MEL_START_TOKEN = LongTensor([self.MEL_DICTIONARY_SIZE-3])
+        self.MEL_STOP_TOKEN = LongTensor([self.MEL_DICTIONARY_SIZE-2])
+
+    def __getitem__(self, index):
+        # Fetch text and add start/stop tokens.
+        audiopath_and_text = self.audiopaths_and_text[index]
+        audiopath, text = audiopath_and_text[0], audiopath_and_text[1]
+        text = torch.IntTensor(text_to_sequence(text, self.text_cleaners))
+        text = torch.cat([self.TEXT_START_TOKEN, text, self.TEXT_STOP_TOKEN], dim=0)
+
+        # Fetch quantized MELs
+        quant_path = audiopath.replace('wavs/', 'quantized_mels/') + '.pth'
+        filename = os.path.join(self.path, quant_path)
+        qmel = torch.load(filename)
+        qmel = torch.cat([self.MEL_START_TOKEN, qmel, self.MEL_STOP_TOKEN])
+
+        return text, qmel, audiopath
+
+    def __len__(self):
+        return len(self.audiopaths_and_text)
+
+
+class GptTtsCollater():
+    NUMBER_SYMBOLS = len(symbols)+3
+    TEXT_PAD_TOKEN = NUMBER_SYMBOLS-1
+
+    def __init__(self, opt):
+        self.MEL_DICTIONARY_SIZE = opt['mel_vocab_size']+3
+        self.MEL_PAD_TOKEN = self.MEL_DICTIONARY_SIZE-1
+
+    def __call__(self, batch):
+        text_lens = [len(x[0]) for x in batch]
+        max_text_len = max(text_lens)
+        mel_lens = [len(x[1]) for x in batch]
+        max_mel_len = max(mel_lens)
+        texts = []
+        qmels = []
+        for b in batch:
+            text, qmel, _ = b
+            texts.append(F.pad(text, (0, max_text_len-len(text)), value=self.TEXT_PAD_TOKEN))
+            qmels.append(F.pad(qmel, (0, max_mel_len-len(qmel)), value=self.MEL_PAD_TOKEN))
+
+        filenames = [j[2] for j in batch]
+
+        return {
+            'padded_text': torch.stack(texts),
+            'input_lengths': LongTensor(text_lens),
+            'padded_qmel': torch.stack(qmels),
+            'output_lengths': LongTensor(mel_lens),
+            'filenames': filenames
+        }
+
+
+if __name__ == '__main__':
+    params = {
+        'mode': 'gpt_tts',
+        'path': 'E:\\audio\\LJSpeech-1.1\\ljs_audio_text_train_filelist.txt',
+        'phase': 'train',
+        'n_workers': 0,
+        'batch_size': 16,
+        'mel_vocab_size': 512,
+    }
+    from data import create_dataset, create_dataloader
+
+    ds, c = create_dataset(params, return_collate=True)
+    dl = create_dataloader(ds, params, collate_fn=c)
+    i = 0
+    m = []
+    max_text = 0
+    max_mel = 0
+    for b in tqdm(dl):
+        max_mel = max(max_mel, b['padded_qmel'].shape[2])
+        max_text = max(max_text, b['padded_text'].shape[1])
+    m = torch.stack(m)
+    print(m.mean(), m.std())
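A quick sketch (not part of the commit) of the token layout the dataset and collater establish: with mel_vocab_size=512, MEL_DICTIONARY_SIZE becomes 515 and the top three ids are reserved, START=512, STOP=513, PAD=514; the text side reserves the analogous three ids above len(symbols). Right-padding with F.pad then fills short sequences with the PAD id:

```python
import torch
import torch.nn.functional as F

mel_vocab_size = 512
MEL_START, MEL_STOP, MEL_PAD = mel_vocab_size, mel_vocab_size + 1, mel_vocab_size + 2

qmel = torch.tensor([MEL_START, 17, 401, 88, MEL_STOP])           # toy code sequence from __getitem__
max_mel_len = 8                                                   # longest sequence in the batch
padded = F.pad(qmel, (0, max_mel_len - len(qmel)), value=MEL_PAD)
print(padded)  # tensor([512,  17, 401,  88, 513, 514, 514, 514])
```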
@@ -41,7 +41,7 @@ class TextMelLoader(torch.utils.data.Dataset):
         audiopath = os.path.join(self.path, audiopath)
         text = self.get_text(text)
         mel = self.get_mel(audiopath)
-        return (text, mel)
+        return (text, mel, audiopath_and_text[0])

     def get_mel(self, filename):
         if not self.load_mel_from_disk:
@@ -88,7 +88,7 @@ class TextMelCollate():
         """Collate's training batch from normalized text and mel-spectrogram
         PARAMS
         ------
-        batch: [text_normalized, mel_normalized]
+        batch: [text_normalized, mel_normalized, filename]
         """
         # Right zero-pad all one-hot text sequences to max input length
         input_lengths, ids_sorted_decreasing = torch.sort(
@@ -121,12 +121,15 @@ class TextMelCollate():
             gate_padded[i, mel.size(1)-1:] = 1
             output_lengths[i] = mel.size(1)

+        filenames = [j[2] for j in batch]
+
         return {
             'padded_text': text_padded,
             'input_lengths': input_lengths,
             'padded_mel': mel_padded,
             'padded_gate': gate_padded,
             'output_lengths': output_lengths,
+            'filenames': filenames
         }

@@ -2,6 +2,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from munch import munchify
+from torch import LongTensor
 from tqdm import tqdm

 from models.arch_util import ConvGnSilu
@@ -45,26 +46,6 @@ class GptTts(nn.Module):
         self.mel_head = nn.Linear(model_dim, self.MEL_DICTIONARY_SIZE)

     def forward(self, text_inputs, text_lengths, mel_targets, output_lengths):
-        output_lengths = output_lengths * 3 // 8 # The data we are dealing with has been compressed by the vqvae.
-        # Add the stop tokens to the end of the texts and mels. Theoretically this would be better done at the dataloader level.
-        batch_range = torch.arange(0, text_inputs.shape[0])
-        text_inputs = F.pad(text_inputs, (0,1))
-        text_inputs.index_put_((batch_range, text_lengths), torch.tensor([self.TEXT_STOP_TOKEN], dtype=torch.long, device=text_inputs.device))
-        text_lengths = text_lengths + 1
-        mel_targets = F.pad(mel_targets, (0,1))
-        mel_targets.index_put_((batch_range, output_lengths), torch.tensor([self.MEL_STOP_TOKEN], dtype=torch.long, device=text_inputs.device))
-        output_lengths = output_lengths + 1
-        # Add the start tokens to the beginnings of the texts and mels.
-        text_inputs = F.pad(text_inputs, (1,0), value=self.TEXT_START_TOKEN)
-        text_lengths = text_lengths + 1
-        mel_targets = F.pad(mel_targets, (1,0), value=self.MEL_START_TOKEN)
-        output_lengths = output_lengths + 1
-        # Add padding as well. This also should realistically be done at the dataloader level.
-        text_pad_mask = ~get_mask_from_lengths(text_lengths, text_inputs.shape[1])
-        text_inputs.data.masked_fill_(text_pad_mask, self.TEXT_PAD_TOKEN)
-        mel_pad_mask = ~get_mask_from_lengths(output_lengths, mel_targets.shape[1])
-        mel_targets.data.masked_fill_(mel_pad_mask, self.MEL_PAD_TOKEN)
-
         text_emb = self.text_embedding(text_inputs)
         text_emb = text_emb + self.text_pos_embedding(torch.arange(text_inputs.shape[1], device=text_inputs.device))
         mel_emb = self.mel_embedding(mel_targets)
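The start/stop/pad bookkeeping deleted above now happens in GptTtsCollater, but the loss weighting retained in the next hunk still builds pad masks with get_mask_from_lengths from the tacotron2 utilities. A minimal sketch of the behaviour that masking relies on (my approximation of the helper, not the repo's exact code):

```python
import torch

def get_mask_from_lengths(lengths, max_len):
    # True at valid positions, False at padding; callers invert it (~mask) to pick out pad slots.
    ids = torch.arange(max_len, device=lengths.device)
    return ids.unsqueeze(0) < lengths.unsqueeze(1)

print(~get_mask_from_lengths(torch.tensor([3, 5]), 6))
# tensor([[False, False, False,  True,  True,  True],
#         [False, False, False, False, False,  True]])
```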
@@ -81,62 +62,43 @@ class GptTts(nn.Module):
         # Compute loss
         loss_text = F.cross_entropy(text_logits.permute(0,2,1)[:,:,1:], text_inputs[:,1:], reduction='none')
         loss_mel = F.cross_entropy(mel_logits.permute(0,2,1)[:,:,1:], mel_targets[:,1:], reduction='none')

         # Apply a reduction factor across MEL_PAD and TEXT_PAD tokens.
         pad_loss_reduction_factor = .01
         text_pad_mask = ~get_mask_from_lengths(text_lengths, text_inputs.shape[1])
         mel_pad_mask = ~get_mask_from_lengths(output_lengths, mel_targets.shape[1])
         loss_text = loss_text * torch.ones_like(loss_text).masked_fill_(text_pad_mask[:,1:], pad_loss_reduction_factor)
         loss_mel = loss_mel * torch.ones_like(loss_mel).masked_fill_(mel_pad_mask[:,1:], pad_loss_reduction_factor)

         # Fix up mel_logits so it can go into a VAE decoder as well.
         mel_codes = torch.argmax(F.softmax(mel_logits, dim=-1), dim=-1)
-        mel_codes = mel_codes[:,1:]
-        mel_codes = mel_codes * torch.ones_like(mel_codes).masked_fill_(mel_pad_mask[:,1:], 0)
-        mel_codes = mel_codes[:,:-1]
+        mel_codes = mel_codes[:,1:-1] # Strip off first and last tokens (START+STOP were added by the dataloader)
+        mel_codes = mel_codes * torch.ones_like(mel_codes).masked_fill_(mel_pad_mask[:,1:-1], 0)
+        extra_mask = mel_codes < self.MEL_DICTIONARY_SIZE-3 # The VAE doesn't know about START/STOP/PAD
+        mel_codes = mel_codes * extra_mask

         return loss_text.mean(), loss_mel.mean(), mel_codes

-    def inference(self, text_inputs, mel_guide):
-        MEL_HEAD_EXPANSION = 2
-        GATE_THRESHOLD = .95
-
+    def inference(self, text_inputs):
         text_emb = self.text_embedding(text_inputs)
-        text_emb = self.text_preprocess_xformer(text_emb, text_emb.shape[1])
-        text_emb = text_emb + self.text_tags
-        b,s,c = text_emb.shape
-        emb = torch.cat([text_emb,
-                         self.separator.repeat(text_emb.shape[0],1,1),], dim=1)
-                         #self.test_guide(mel_guide)], dim=1)
-        completed = torch.zeros((b,), device=text_inputs.device, dtype=torch.bool)
-        output = None
-        for i in tqdm(range(self.max_mel_frames)):
-            enc = self.gpt(emb, text_emb.shape[1])
-            inferred = enc[:,s:,:].permute(0,2,1)
-            # Create output frames.
-            inferred_mel_frame = self.mel_head(inferred)[:,:,-MEL_HEAD_EXPANSION:]
-            inferred_mel_frame = inferred_mel_frame * (~completed).float().view(b,1,1)
-            if output is None:
-                output = inferred_mel_frame
-            else:
-                output = torch.cat([output, inferred_mel_frame], dim=2)
+        text_emb = text_emb + self.text_pos_embedding(torch.arange(text_inputs.shape[1], device=text_inputs.device))

-            # Test termination condition
-            gate = F.sigmoid(self.gate_head(inferred)).max(dim=-1).values # TODO: accept single-frame terminations.
-            completed = completed.logical_or((gate > GATE_THRESHOLD).squeeze(1)) # This comprises a latch - but that may not be wise.
-            if torch.all(completed):
-                break
+        mel_seq = [self.MEL_START_TOKEN, 0]
+        while mel_seq[-1] != self.MEL_STOP_TOKEN and len(mel_seq) < self.max_mel_frames:
+            mel_emb = self.mel_embedding(LongTensor(mel_seq, device=text_inputs.device))
+            mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_seq.shape[1], device=mel_seq.device))
+            emb = torch.cat([text_emb, mel_emb], dim=1)
+            enc = self.gpt(emb)
+            mel_logits = self.final_norm(enc[:, text_emb.shape[1]:])
+            mel_logits = self.mel_head(mel_logits)
+            mel_codes = torch.argmax(F.softmax(mel_logits, dim=-1), dim=-1)
+            mel_seq[-1] = mel_codes[-1]
+            mel_seq.append(0)

-            # Apply inferred mel_frames to emb for next pass.
-            mel_emb = self.mel_encoder(output).permute(0,2,1)
-            mel_emb = mel_emb + self.audio_tags
-            emb = torch.cat([text_emb,
-                             self.separator.repeat(text_emb.shape[0],1,1),
-                             mel_emb], dim=1)
-            if i == self.max_mel_frames//2:
-                print("Warning! Inference hit mel frame cap without encountering a stop token.")
-                break
+        if len(mel_seq) >= self.max_mel_frames:
+            print("Warning! Encountered frame limit before a stop token. Output is likely wrong.")

-        return output
+        return mel_seq[:-1]


 @register_model
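A hedged sketch of driving the new autoregressive inference loop end to end. The text bookkeeping mirrors GptTtsDataset above; the loaded gpt_tts model instance and the final DVAE decode are assumptions, not part of this commit:

```python
import torch
from models.tacotron2.text import text_to_sequence, symbols

NUMBER_SYMBOLS = len(symbols) + 3
TEXT_START, TEXT_STOP = NUMBER_SYMBOLS - 3, NUMBER_SYMBOLS - 2

# Wrap the prompt exactly like the dataset does: START + cleaned symbol ids + STOP.
text = torch.LongTensor(text_to_sequence('Hello world.', ['english_cleaners']))
text = torch.cat([torch.LongTensor([TEXT_START]), text, torch.LongTensor([TEXT_STOP])]).unsqueeze(0)

with torch.no_grad():
    mel_codes = gpt_tts.inference(text)  # gpt_tts: a constructed and loaded GptTts model (assumed)
# mel_codes are DVAE codebook indices; turning them back into audio would go through the same
# DVAE used by generate_quantized_mels.py below, which this commit does not wire up yet.
```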
@@ -223,7 +223,7 @@ class VQVAE(nn.Module):

         quant_t = self.quantize_conv_t(enc_t).permute((0,2,3,1) if len(input.shape) == 4 else (0,2,1))
         quant_t, diff_t, id_t = self.quantize_t(quant_t)
-        quant_t = quant_t.permute((0,3,1,2) if len(input) == 4 else (0,2,1))
+        quant_t = quant_t.permute((0,3,1,2) if len(input.shape) == 4 else (0,2,1))
         diff_t = diff_t.unsqueeze(0)

         dec_t = checkpoint(self.dec_t, quant_t)

@@ -231,7 +231,7 @@ class VQVAE(nn.Module):

         quant_b = checkpoint(self.quantize_conv_b, enc_b).permute((0,2,3,1) if len(input.shape) == 4 else (0,2,1))
         quant_b, diff_b, id_b = self.quantize_b(quant_b)
-        quant_b = quant_b.permute((0,3,1,2) if len(input) == 4 else (0,2,1))
+        quant_b = quant_b.permute((0,3,1,2) if len(input.shape) == 4 else (0,2,1))
         diff_b = diff_b.unsqueeze(0)

         return quant_t, quant_b, diff_t + diff_b, id_t, id_b
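The two VQVAE changes above are a real bug fix rather than a refactor: on a tensor, len(x) is the size of the first (batch) dimension, not the number of dimensions, so the old len(input) == 4 test depended on the batch size. A two-line illustration:

```python
import torch

x = torch.zeros(4, 80, 200)   # a 3-D (batch, n_mels, frames) input that happens to have batch size 4
print(len(x), len(x.shape))   # 4 3 -> the old check would wrongly pick the 4-D permute
```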
codes/scripts/audio/generate_quantized_mels.py (new file, 68 lines)
@@ -0,0 +1,68 @@
+import os
+import os.path as osp
+import logging
+import random
+import argparse
+
+import torchvision
+
+import utils
+import utils.options as option
+import utils.util as util
+from models.waveglow.denoiser import Denoiser
+from trainer.ExtensibleTrainer import ExtensibleTrainer
+from data import create_dataset, create_dataloader
+from tqdm import tqdm
+import torch
+import numpy as np
+from scipy.io import wavfile
+
+
+if __name__ == "__main__":
+    # Set seeds
+    torch.manual_seed(5555)
+    random.seed(5555)
+    np.random.seed(5555)
+
+    #### options
+    torch.backends.cudnn.benchmark = True
+    want_metrics = False
+    parser = argparse.ArgumentParser()
+    parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/generate_quantized_mels.yml')
+    opt = option.parse(parser.parse_args().opt, is_train=False)
+    opt = option.dict_to_nonedict(opt)
+    utils.util.loaded_options = opt
+
+    util.mkdirs(
+        (path for key, path in opt['path'].items()
+         if not key == 'experiments_root' and 'pretrain_model' not in key and 'resume' not in key))
+    util.setup_logger('base', opt['path']['log'], 'test_' + opt['name'], level=logging.INFO,
+                      screen=True, tofile=True)
+    logger = logging.getLogger('base')
+    logger.info(option.dict2str(opt))
+
+    test_loaders = []
+    for phase, dataset_opt in sorted(opt['datasets'].items()):
+        test_set, collate_fn = create_dataset(dataset_opt, return_collate=True)
+        test_loader = create_dataloader(test_set, dataset_opt, collate_fn=collate_fn)
+        logger.info('Number of test texts in [{:s}]: {:d}'.format(dataset_opt['name'], len(test_set)))
+        test_loaders.append(test_loader)
+
+    model = ExtensibleTrainer(opt)
+
+    outpath = opt['path']['results_root']
+    os.makedirs(os.path.join(outpath, 'quantized_mels'), exist_ok=True)
+    for test_loader in test_loaders:
+        dataset_dir = opt['path']['results_root']
+        util.mkdir(dataset_dir)
+
+        tq = tqdm(test_loader)
+        for data in tq:
+            with torch.no_grad():
+                model.feed_data(data, 0)
+                model.test()
+
+            wavfiles = data['filenames']
+            quantized = model.eval_state[opt['eval']['quantized_mels']][0]
+            for i, wavfile in enumerate(wavfiles):
+                qmelfile = wavfile.replace('wavs/', 'quantized_mels/') + '.pth'
+                torch.save(quantized[i], os.path.join(outpath, qmelfile))
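After the script runs, each saved .pth should hold a 1-D integer tensor of DVAE codebook indices that GptTtsDataset can concatenate START/STOP tokens onto directly. A quick check one might run (the path is illustrative; the real location is <results_root>/quantized_mels/, mirroring the wavs/ layout):

```python
import torch

qmel = torch.load('results/quantized_mels/LJ001-0001.wav.pth')  # illustrative path
print(qmel.shape, qmel.dtype)            # expected: 1-D integer codes
print(int(qmel.min()), int(qmel.max()))  # should stay below mel_vocab_size (512 in the example config)
```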