From 007976082b0a01973d3438974ec84141205da241 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Sat, 14 Aug 2021 14:37:17 -0600
Subject: [PATCH] GPT_asr for inference

---
 codes/data/audio/nv_tacotron_dataset.py            | 42 +++++-----
 codes/models/gpt_voice/gpt_asr.py                  | 19 ++---
 codes/models/tacotron2/text/__init__.py            |  5 ++
 .../audio/test_audio_speech_recognition.py         | 76 +++++++++++++++++++
 4 files changed, 112 insertions(+), 30 deletions(-)
 create mode 100644 codes/scripts/audio/test_audio_speech_recognition.py

diff --git a/codes/data/audio/nv_tacotron_dataset.py b/codes/data/audio/nv_tacotron_dataset.py
index b9461977..8257e080 100644
--- a/codes/data/audio/nv_tacotron_dataset.py
+++ b/codes/data/audio/nv_tacotron_dataset.py
@@ -62,9 +62,9 @@ class TextMelLoader(torch.utils.data.Dataset):
         # separate filename and text
         audiopath, text = audiopath_and_text[0], audiopath_and_text[1]
         audiopath = os.path.join(self.path, audiopath)
-        text = self.get_text(text)
+        text_seq = self.get_text(text)
         mel = self.get_mel(audiopath)
-        return (text, mel, audiopath_and_text[0])
+        return (text_seq, mel, text, audiopath_and_text[0])
 
     def get_mel(self, filename):
         if not self.load_mel_from_disk:
@@ -106,30 +106,31 @@ class TextMelLoader(torch.utils.data.Dataset):
         return text_norm
 
     def __getitem__(self, index):
-        t, m, p = self.get_mel_text_pair(self.audiopaths_and_text[index])
-        if m is None or \
-                (self.max_mel_len is not None and m.shape[-1] > self.max_mel_len) or \
-                (self.max_text_len is not None and t.shape[0] > self.max_text_len):
-            if m is not None:
-                print(f"Exception {index} mel_len:{m.shape[-1]} text_len:{t.shape[0]} fname: {p}")
+        tseq, mel, text, path = self.get_mel_text_pair(self.audiopaths_and_text[index])
+        if mel is None or \
+                (self.max_mel_len is not None and mel.shape[-1] > self.max_mel_len) or \
+                (self.max_text_len is not None and tseq.shape[0] > self.max_text_len):
+            if mel is not None:
+                print(f"Exception {index} mel_len:{mel.shape[-1]} text_len:{tseq.shape[0]} fname: {path}")
             # It's hard to handle this situation properly. Best bet is to return a random valid item and skew the dataset somewhat as a result.
             rv = random.randint(0, len(self) - 1)
             return self[rv]
-        orig_output = m.shape[-1]
-        orig_text_len = t.shape[0]
+        orig_output = mel.shape[-1]
+        orig_text_len = tseq.shape[0]
         if not self.needs_collate:
-            if m.shape[-1] != self.max_mel_len:
-                m = F.pad(m, (0, self.max_mel_len - m.shape[-1]))
-            if t.shape[0] != self.max_text_len:
-                t = F.pad(t, (0, self.max_text_len - t.shape[0]))
+            if mel.shape[-1] != self.max_mel_len:
+                mel = F.pad(mel, (0, self.max_mel_len - mel.shape[-1]))
+            if tseq.shape[0] != self.max_text_len:
+                tseq = F.pad(tseq, (0, self.max_text_len - tseq.shape[0]))
             return {
-                'padded_text': t,
+                'real_text': text,
+                'padded_text': tseq,
                 'input_lengths': torch.tensor(orig_text_len, dtype=torch.long),
-                'padded_mel': m,
+                'padded_mel': mel,
                 'output_lengths': torch.tensor(orig_output, dtype=torch.long),
-                'filenames': p
+                'filenames': path
             }
-        return t, m, p
+        return tseq, mel, path, text
 
     def __len__(self):
         return len(self.audiopaths_and_text)
@@ -156,10 +157,12 @@ class TextMelCollate():
         text_padded = torch.LongTensor(len(batch), max_input_len)
         text_padded.zero_()
         filenames = []
+        real_text = []
         for i in range(len(ids_sorted_decreasing)):
             text = batch[ids_sorted_decreasing[i]][0]
             text_padded[i, :text.size(0)] = text
             filenames.append(batch[ids_sorted_decreasing[i]][2])
+            real_text.append(batch[ids_sorted_decreasing[i]][3])
 
         # Right zero-pad mel-spec
         num_mels = batch[0][1].size(0)
@@ -186,7 +189,8 @@ class TextMelCollate():
             'padded_mel': mel_padded,
             'padded_gate': gate_padded,
             'output_lengths': output_lengths,
-            'filenames': filenames
+            'filenames': filenames,
+            'real_text': real_text,
         }
 
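Note on the dataset change: each item now carries the raw transcript alongside the tokenized sequence ('real_text' next to 'padded_text'), so evaluation code can print ground truth without de-tokenizing. A minimal consumer sketch, assuming a DataLoader built around TextMelLoader with needs_collate disabled; the loop and variable names are illustrative, not part of the patch:

    # Hypothetical usage sketch: read the new 'real_text' field next to the
    # padded token sequence from a collated batch.
    for batch in test_loader:
        tokens = batch['padded_text']   # LongTensor [B, max_text_len]
        raw = batch['real_text']        # list of B ground-truth strings
        print(raw[0], tokens[0, :10])
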
diff --git a/codes/models/gpt_voice/gpt_asr.py b/codes/models/gpt_voice/gpt_asr.py
index af4995a7..c63b2b71 100644
--- a/codes/models/gpt_voice/gpt_asr.py
+++ b/codes/models/gpt_voice/gpt_asr.py
@@ -5,7 +5,7 @@ from munch import munchify
 
 from models.gpt_voice.lucidrains_gpt import Transformer
 from models.tacotron2.taco_utils import get_mask_from_lengths
-from models.tacotron2.text import symbols
+from models.tacotron2.text import symbols, sequence_to_text
 from trainer.networks import register_model
 from utils.util import opt_get
 
@@ -125,11 +125,11 @@ class GptAsr(nn.Module):
             text_emb = text_emb + self.text_pos_embedding(torch.arange(text_emb.shape[1], device=mel_emb.device))
             if text_emb.shape[0] != mel_emb.shape[0]:
                 mel_emb = mel_emb.repeat(text_emb.shape[0], 1, 1)
-            emb = torch.cat([text_emb, mel_emb], dim=1)
+            emb = torch.cat([mel_emb, text_emb], dim=1)
             enc = self.gpt(emb)
-            mel_logits = self.final_norm(enc[:, text_emb.shape[1]:])
-            mel_logits = self.mel_head(mel_logits)
-            topk = sampler_fn(F.softmax(temperature * mel_logits[:, -1], dim=-1), k=beam_width)
+            text_logits = self.final_norm(enc[:, mel_emb.shape[1]:])
+            text_logits = self.text_head(text_logits)
+            topk = sampler_fn(F.softmax(temperature * text_logits[:, -1], dim=-1), k=beam_width)
             probabilities = (probabilities.repeat_interleave(beam_width, dim=0) * topk.values.flatten())
             probabilities, sort_indices = torch.sort(probabilities, descending=True)
             probabilities = probabilities[:beam_width]
@@ -140,15 +140,12 @@ class GptAsr(nn.Module):
             text_seq = text_seq[sort_indices]
             text_seq = text_seq[:beam_width]
 
-            if torch.all(torch.any(text_seq == self.MEL_STOP_TOKEN, dim=1)):
+            # PAD doubles as a stop token. PAD=0.
+            if torch.all(torch.any(text_seq == 0, dim=1)):
                 break
             if text_seq.shape[1] >= self.max_mel_frames:
-                print("Warning! Encountered frame limit before a stop token. Output is likely wrong.")
-
-        # Format mel_seq so that the DVAE can actually use it (it is a two-tiered DVAE)
-        text_seq = text_seq[0, 1:-1].unsqueeze(0)  # Pick most likely outcome, remove first and last tokens, which were artificially added for GPT
-        text_seq = text_seq * (text_seq < 512)  # The DVAE doesn't understand BOS/EOS/PAD tokens.
+                print("Warning! Encountered frame limit before a pad token. Output is likely wrong.")
 
         return text_seq
 
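The essential fix above is the ordering of the concatenated sequence: for ASR the mel spectrogram is the prompt and the text tokens are the continuation, so next-token logits must be read from the positions after the mel segment and fed through text_head rather than mel_head. A shape-only sketch of that layout; all sizes and the identity stand-in for the transformer are invented for illustration:

    import torch

    B, M, T, C = 1, 200, 12, 512                 # batch, mel frames, text tokens so far, model width
    mel_emb = torch.randn(B, M, C)               # conditioning: the encoded mel spectrogram
    text_emb = torch.randn(B, T, C)              # embeddings of the text decoded so far
    emb = torch.cat([mel_emb, text_emb], dim=1)  # [B, M+T, C]: mel first, text second
    enc = emb                                    # stand-in for the transformer, self.gpt(emb)
    text_enc = enc[:, mel_emb.shape[1]:]         # keep only the text positions
    next_logits = text_enc[:, -1]                # last text position predicts the next token
    print(next_logits.shape)                     # torch.Size([1, 512])
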
diff --git a/codes/models/tacotron2/text/__init__.py b/codes/models/tacotron2/text/__init__.py
index 6d46acb9..a392b10a 100644
--- a/codes/models/tacotron2/text/__init__.py
+++ b/codes/models/tacotron2/text/__init__.py
@@ -1,5 +1,8 @@
 """ from https://github.com/keithito/tacotron """
 import re
+
+import torch
+
 from models.tacotron2.text import cleaners
 from models.tacotron2.text.symbols import symbols
 
@@ -44,6 +47,8 @@ def sequence_to_text(sequence):
   '''Converts a sequence of IDs back to a string'''
   result = ''
   for symbol_id in sequence:
+    if isinstance(symbol_id, torch.Tensor):
+      symbol_id = symbol_id.item()
     if symbol_id in _id_to_symbol:
       s = _id_to_symbol[symbol_id]
       # Enclose ARPAbet back in curly braces:
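sequence_to_text previously assumed plain Python ints; the model's sampler yields torch tensors, so each element is now unwrapped with .item() before the dictionary lookup. A quick sanity check of both input types (the token IDs below are arbitrary, purely illustrative):

    import torch
    from models.tacotron2.text import sequence_to_text

    ids = [10, 11, 12]
    # A list of ints and a 1-D LongTensor should now decode identically:
    assert sequence_to_text(ids) == sequence_to_text(torch.LongTensor(ids))
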
diff --git a/codes/scripts/audio/test_audio_speech_recognition.py b/codes/scripts/audio/test_audio_speech_recognition.py
new file mode 100644
index 00000000..28d2532f
--- /dev/null
+++ b/codes/scripts/audio/test_audio_speech_recognition.py
@@ -0,0 +1,76 @@
+import os.path as osp
+import logging
+import random
+import argparse
+
+import torchvision
+
+import utils
+import utils.options as option
+import utils.util as util
+from models.tacotron2.text import sequence_to_text
+from trainer.ExtensibleTrainer import ExtensibleTrainer
+from data import create_dataset, create_dataloader
+from tqdm import tqdm
+import torch
+import numpy as np
+from scipy.io import wavfile
+
+
+def forward_pass(model, data, output_dir, opt, b):
+    with torch.no_grad():
+        model.feed_data(data, 0)
+        model.test()
+
+    real = data[opt['eval']['real_text']][0]
+    pred_seq = model.eval_state[opt['eval']['gen_text']][0]
+    pred_text = [sequence_to_text(ts) for ts in pred_seq]
+    audio = model.eval_state[opt['eval']['audio']][0].cpu().numpy()
+    wavfile.write(osp.join(output_dir, f'{b}_clip.wav'), 22050, audio)
+    print(f'{b} Real text: "{real}"')
+    for i, text in enumerate(pred_text):
+        print(f'{b} Predicted text {i}: "{text}"')
+
+
+if __name__ == "__main__":
+    # Set seeds
+    torch.manual_seed(5555)
+    random.seed(5555)
+    np.random.seed(5555)
+
+    #### options
+    torch.backends.cudnn.benchmark = True
+    want_metrics = False
+    parser = argparse.ArgumentParser()
+    parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_gpt_asr_mozcv.yml')
+    opt = option.parse(parser.parse_args().opt, is_train=False)
+    opt = option.dict_to_nonedict(opt)
+    utils.util.loaded_options = opt
+
+    util.mkdirs(
+        (path for key, path in opt['path'].items()
+         if not key == 'experiments_root' and 'pretrain_model' not in key and 'resume' not in key))
+    util.setup_logger('base', opt['path']['log'], 'test_' + opt['name'], level=logging.INFO,
+                      screen=True, tofile=True)
+    logger = logging.getLogger('base')
+    logger.info(option.dict2str(opt))
+
+    test_loaders = []
+    for phase, dataset_opt in sorted(opt['datasets'].items()):
+        test_set, collate_fn = create_dataset(dataset_opt, return_collate=True)
+        test_loader = create_dataloader(test_set, dataset_opt, collate_fn=collate_fn)
+        logger.info('Number of test texts in [{:s}]: {:d}'.format(dataset_opt['name'], len(test_set)))
+        test_loaders.append(test_loader)
+
+    model = ExtensibleTrainer(opt)
+
+    batch = 0
+    for test_loader in test_loaders:
+        dataset_dir = opt['path']['results_root']
+        util.mkdir(dataset_dir)
+
+        tq = tqdm(test_loader)
+        for data in tq:
+            forward_pass(model, data, dataset_dir, opt, batch)
+            batch += 1
+
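For reference, a hedged example of invoking the new script. The YAML path is the script's default; whether the relative '../options' path resolves depends on your working directory, and the referenced config must define eval keys naming the real_text/gen_text/audio state entries:

    cd codes
    python scripts/audio/test_audio_speech_recognition.py -opt ../options/test_gpt_asr_mozcv.yml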