GPT_asr for inference

This commit is contained in:
James Betker 2021-08-14 14:37:17 -06:00
parent e1bdd3f7c7
commit 007976082b
4 changed files with 112 additions and 30 deletions

View File

@ -62,9 +62,9 @@ class TextMelLoader(torch.utils.data.Dataset):
# separate filename and text
audiopath, text = audiopath_and_text[0], audiopath_and_text[1]
audiopath = os.path.join(self.path, audiopath)
text = self.get_text(text)
text_seq = self.get_text(text)
mel = self.get_mel(audiopath)
return (text, mel, audiopath_and_text[0])
return (text_seq, mel, text, audiopath_and_text[0])
def get_mel(self, filename):
if not self.load_mel_from_disk:
@ -106,30 +106,31 @@ class TextMelLoader(torch.utils.data.Dataset):
return text_norm
def __getitem__(self, index):
t, m, p = self.get_mel_text_pair(self.audiopaths_and_text[index])
if m is None or \
(self.max_mel_len is not None and m.shape[-1] > self.max_mel_len) or \
(self.max_text_len is not None and t.shape[0] > self.max_text_len):
if m is not None:
print(f"Exception {index} mel_len:{m.shape[-1]} text_len:{t.shape[0]} fname: {p}")
tseq, mel, text, path = self.get_mel_text_pair(self.audiopaths_and_text[index])
if mel is None or \
(self.max_mel_len is not None and mel.shape[-1] > self.max_mel_len) or \
(self.max_text_len is not None and tseq.shape[0] > self.max_text_len):
if mel is not None:
print(f"Exception {index} mel_len:{mel.shape[-1]} text_len:{tseq.shape[0]} fname: {path}")
# It's hard to handle this situation properly. Best bet is to return the a random valid token and skew the dataset somewhat as a result.
rv = random.randint(0,len(self)-1)
return self[rv]
orig_output = m.shape[-1]
orig_text_len = t.shape[0]
orig_output = mel.shape[-1]
orig_text_len = tseq.shape[0]
if not self.needs_collate:
if m.shape[-1] != self.max_mel_len:
m = F.pad(m, (0, self.max_mel_len - m.shape[-1]))
if t.shape[0] != self.max_text_len:
t = F.pad(t, (0, self.max_text_len - t.shape[0]))
if mel.shape[-1] != self.max_mel_len:
mel = F.pad(mel, (0, self.max_mel_len - mel.shape[-1]))
if tseq.shape[0] != self.max_text_len:
tseq = F.pad(tseq, (0, self.max_text_len - tseq.shape[0]))
return {
'padded_text': t,
'real_text': text,
'padded_text': tseq,
'input_lengths': torch.tensor(orig_text_len, dtype=torch.long),
'padded_mel': m,
'padded_mel': mel,
'output_lengths': torch.tensor(orig_output, dtype=torch.long),
'filenames': p
'filenames': path
}
return t, m, p
return tseq, mel, path, text
def __len__(self):
return len(self.audiopaths_and_text)
@ -156,10 +157,12 @@ class TextMelCollate():
text_padded = torch.LongTensor(len(batch), max_input_len)
text_padded.zero_()
filenames = []
real_text = []
for i in range(len(ids_sorted_decreasing)):
text = batch[ids_sorted_decreasing[i]][0]
text_padded[i, :text.size(0)] = text
filenames.append(batch[ids_sorted_decreasing[i]][2])
real_text.append(batch[ids_sorted_decreasing[i]][3])
# Right zero-pad mel-spec
num_mels = batch[0][1].size(0)
@ -186,7 +189,8 @@ class TextMelCollate():
'padded_mel': mel_padded,
'padded_gate': gate_padded,
'output_lengths': output_lengths,
'filenames': filenames
'filenames': filenames,
'real_text': real_text,
}

View File

@ -5,7 +5,7 @@ from munch import munchify
from models.gpt_voice.lucidrains_gpt import Transformer
from models.tacotron2.taco_utils import get_mask_from_lengths
from models.tacotron2.text import symbols
from models.tacotron2.text import symbols, sequence_to_text
from trainer.networks import register_model
from utils.util import opt_get
@ -125,11 +125,11 @@ class GptAsr(nn.Module):
text_emb = text_emb + self.text_pos_embedding(torch.arange(text_emb.shape[1], device=mel_emb.device))
if text_emb.shape[0] != mel_emb.shape[0]:
mel_emb = mel_emb.repeat(text_emb.shape[0], 1, 1)
emb = torch.cat([text_emb, mel_emb], dim=1)
emb = torch.cat([mel_emb, text_emb], dim=1)
enc = self.gpt(emb)
mel_logits = self.final_norm(enc[:, text_emb.shape[1]:])
mel_logits = self.mel_head(mel_logits)
topk = sampler_fn(F.softmax(temperature * mel_logits[:, -1], dim=-1), k=beam_width)
text_logits = self.final_norm(enc[:, mel_emb.shape[1]:])
text_logits = self.text_head(text_logits)
topk = sampler_fn(F.softmax(temperature * text_logits[:, -1], dim=-1), k=beam_width)
probabilities = (probabilities.repeat_interleave(beam_width, dim=0) * topk.values.flatten())
probabilities, sort_indices = torch.sort(probabilities, descending=True)
probabilities = probabilities[:beam_width]
@ -140,15 +140,12 @@ class GptAsr(nn.Module):
text_seq = text_seq[sort_indices]
text_seq = text_seq[:beam_width]
if torch.all(torch.any(text_seq == self.MEL_STOP_TOKEN, dim=1)):
# PAD doubles as a stop token. PAD=0.
if torch.all(torch.any(text_seq == 0, dim=1)):
break
if text_seq.shape[1] >= self.max_mel_frames:
print("Warning! Encountered frame limit before a stop token. Output is likely wrong.")
# Format mel_seq so that the DVAE can actually use it (it is a two-tiered DVAE)
text_seq = text_seq[0, 1:-1].unsqueeze(0) # Pick most likely outcome, remove first and last tokens, which were artificially added for GPT
text_seq = text_seq * (text_seq < 512) # The DVAE doesn't understand BOS/EOS/PAD tokens.
print("Warning! Encountered frame limit before a pad token. Output is likely wrong.")
return text_seq

View File

@ -1,5 +1,8 @@
""" from https://github.com/keithito/tacotron """
import re
import torch
from models.tacotron2.text import cleaners
from models.tacotron2.text.symbols import symbols
@ -44,6 +47,8 @@ def sequence_to_text(sequence):
'''Converts a sequence of IDs back to a string'''
result = ''
for symbol_id in sequence:
if isinstance(symbol_id, torch.Tensor):
symbol_id = symbol_id.item()
if symbol_id in _id_to_symbol:
s = _id_to_symbol[symbol_id]
# Enclose ARPAbet back in curly braces:

View File

@ -0,0 +1,76 @@
import os.path as osp
import logging
import random
import argparse
import torchvision
import utils
import utils.options as option
import utils.util as util
from models.tacotron2.text import sequence_to_text
from trainer.ExtensibleTrainer import ExtensibleTrainer
from data import create_dataset, create_dataloader
from tqdm import tqdm
import torch
import numpy as np
from scipy.io import wavfile
def forward_pass(model, data, output_dir, opt, b):
with torch.no_grad():
model.feed_data(data, 0)
model.test()
real = data[opt['eval']['real_text']][0]
pred_seq = model.eval_state[opt['eval']['gen_text']][0]
pred_text = [sequence_to_text(ts) for ts in pred_seq]
audio = model.eval_state[opt['eval']['audio']][0].cpu().numpy()
wavfile.write(osp.join(output_dir, f'{b}_clip.wav'), 22050, audio)
print(f'{b} Real text: "{real}"')
for i, text in enumerate(pred_text):
print(f'{b} Predicted text {i}: "{text}"')
if __name__ == "__main__":
# Set seeds
torch.manual_seed(5555)
random.seed(5555)
np.random.seed(5555)
#### options
torch.backends.cudnn.benchmark = True
want_metrics = False
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_gpt_asr_mozcv.yml')
opt = option.parse(parser.parse_args().opt, is_train=False)
opt = option.dict_to_nonedict(opt)
utils.util.loaded_options = opt
util.mkdirs(
(path for key, path in opt['path'].items()
if not key == 'experiments_root' and 'pretrain_model' not in key and 'resume' not in key))
util.setup_logger('base', opt['path']['log'], 'test_' + opt['name'], level=logging.INFO,
screen=True, tofile=True)
logger = logging.getLogger('base')
logger.info(option.dict2str(opt))
test_loaders = []
for phase, dataset_opt in sorted(opt['datasets'].items()):
test_set, collate_fn = create_dataset(dataset_opt, return_collate=True)
test_loader = create_dataloader(test_set, dataset_opt, collate_fn=collate_fn)
logger.info('Number of test texts in [{:s}]: {:d}'.format(dataset_opt['name'], len(test_set)))
test_loaders.append(test_loader)
model = ExtensibleTrainer(opt)
batch = 0
for test_loader in test_loaders:
dataset_dir = opt['path']['results_root']
util.mkdir(dataset_dir)
tq = tqdm(test_loader)
for data in tq:
forward_pass(model, data, dataset_dir, opt, batch)
batch += 1