From 986fc9628d8c0e3420e423ad37438842d7f86202 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Fri, 29 Oct 2021 17:21:40 -0600
Subject: [PATCH] Check in GPT with new inference methods (but not the backing code..)

---
 codes/models/gpt_voice/gpt_asr.py | 89 +++++++++++++++++++++++++------
 codes/scripts/audio/asr_eval.py   |  6 +--
 2 files changed, 77 insertions(+), 18 deletions(-)

diff --git a/codes/models/gpt_voice/gpt_asr.py b/codes/models/gpt_voice/gpt_asr.py
index 76a99c34..fa1df98f 100644
--- a/codes/models/gpt_voice/gpt_asr.py
+++ b/codes/models/gpt_voice/gpt_asr.py
@@ -1,3 +1,5 @@
+from time import time
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -92,12 +94,12 @@ class GptAsr(nn.Module):
         loss_text = F.cross_entropy(text_logits[:,:,:-1], text_targets[:,1:].long())
         return loss_text.mean(), text_logits
 
-    def inference_beam_topk(self, mel):
+    def inference_beam_topk(self, mel, fn='inference_beam'):
         def topk_sampler(distribution, k):
             return torch.topk(distribution, k=k, dim=-1)
-        return self.inference_beam(mel, topk_sampler)
+        return getattr(self, fn)(mel, topk_sampler)
 
-    def inference_beam_sampled(self, mel):
+    def inference_beam_sampled(self, mel, fn='inference_beam'):
         def multinomial_sampler(distribution, k):
             indices = torch.multinomial(distribution, num_samples=k, replacement=False)
             values = torch.gather(distribution, dim=1, index=indices)
@@ -106,7 +108,7 @@ class GptAsr(nn.Module):
                 self.indices = i
                 self.values = v
             return container(indices, values)
-        return self.inference_beam(mel, multinomial_sampler)
+        return getattr(self, fn)(mel, multinomial_sampler)
 
     def inference_beam(self, mel_inputs, sampler_fn):
         beam_width = 16
@@ -151,33 +153,90 @@ class GptAsr(nn.Module):
 
         return text_seq
 
+    def inference_beam_opt(self, mel_inputs, sampler_fn):
+        beam_width = 16
+        temperature = .8
+
+        b, _, s = mel_inputs.shape
+        assert b == 1  # Beam search only works on batches of one.
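+        # Encode the mel conditioning once, outside the decoding loop. Each
+        # iteration below re-uses the cached transformer intermediates (via
+        # infer_last_two) instead of re-running the full sequence every step.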
+        mel_emb = self.mel_encoder(mel_inputs)
+        mel_emb = F.pad(mel_emb, (0, self.max_mel_frames - mel_emb.shape[-1]))
+        mel_emb = mel_emb.permute(0,2,1).contiguous()
+        mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device))
+
+        intermediates = []
+        text_seq = torch.full((b,1), fill_value=self.NUMBER_SYMBOLS, device=mel_emb.device)
+        probabilities = torch.ones((b,), device=mel_emb.device)
+        while text_seq.shape[-1] < self.max_symbols_per_phrase:
+            text_emb = self.text_embedding(text_seq)
+            text_emb = text_emb + self.text_pos_embedding(torch.arange(text_emb.shape[1], device=mel_emb.device))
+            if text_emb.shape[0] != mel_emb.shape[0]:
+                mel_emb = mel_emb.repeat(text_emb.shape[0], 1, 1)
+            emb = torch.cat([mel_emb, text_emb], dim=1)
+
+            if len(intermediates) == 0:
+                enc, intermediates = self.gpt(emb, return_intermediates=True)
+                intermediates = [(i[0].repeat(beam_width, 1, 1),
+                                  i[1].repeat(beam_width, 1, 1)) for i in intermediates]
+            else:
+                enc, intermediates = self.gpt.infer_last_two(emb, intermediates)
+
+            text_logits = self.final_norm(enc[:, mel_emb.shape[1]:])
+            text_logits = self.text_head(text_logits)
+            topk = sampler_fn(F.softmax(temperature * text_logits[:, -1], dim=-1), k=beam_width)
+            probabilities = (probabilities.repeat_interleave(beam_width, dim=0) * topk.values.flatten())
+            probabilities, sort_indices = torch.sort(probabilities, descending=True)
+            probabilities = probabilities[:beam_width]
+
+            text_seq = text_seq.repeat_interleave(beam_width, dim=0)
+            codes = topk.indices.flatten()
+            text_seq = torch.cat([text_seq, codes.unsqueeze(1)], dim=1)
+            text_seq = text_seq[sort_indices]
+            text_seq = text_seq[:beam_width]
+
+            # PAD doubles as a stop token. PAD=0.
+            if torch.all(torch.any(text_seq == 0, dim=1)):
+                break
+
+        if text_seq.shape[1] >= self.max_symbols_per_phrase:
+            print("Warning! Encountered symbol limit before a pad token. Output is likely wrong.")
+
+        return text_seq
+
 
 @register_model
 def register_gpt_asr(opt_net, opt):
     return GptAsr(**opt_get(opt_net, ['kwargs'], {}))
 
 
-# Halves the number of layers in the provided model.
-def distill(model):
+# Quick script that loads a model and halves the number of layers, then saves that model.
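+# (It deletes every other transformer layer in place, keeping the
+# even-indexed ones, so the saved checkpoint has half the original depth.)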
+def distill():
+    gpt = GptAsr(max_symbols_per_phrase=250, max_mel_frames=1400, layers=12, model_dim=768, heads=12)
+    gpt.load_state_dict(torch.load('../experiments/train_gpt_asr_mass/models/21500_mel_gen.pth'))
     rc = 0
     i = 0
-    while i < len(model.gpt.layers.layers):
+    while i < len(gpt.gpt.layers.layers):
         if rc % 2 != 0:
-            del model.gpt.layers.layers[i]
+            del gpt.gpt.layers.layers[i]
         else:
             i += 1
         rc += 1
-    return model
+    torch.save(gpt.state_dict(), '../experiments/train_gpt_asr_mass/models/21500_mel_gen_distilled.pth')
 
 
 if __name__ == '__main__':
-    gpt = GptAsr(max_symbols_per_phrase=250, max_mel_frames=1400, layers=12, model_dim=768, heads=12)
-    gpt.load_state_dict(torch.load('../experiments/train_gpt_asr_mass/models/21500_mel_gen.pth'))
-    student = distill(gpt)
-    torch.save(student.state_dict(), '../experiments/train_gpt_asr_mass/models/21500_mel_gen_distilled.pth')
+    gpt = GptAsr(max_symbols_per_phrase=100, max_mel_frames=200, layers=6, model_dim=256, heads=2).cuda()
     #l = gpt(torch.randn(2,80,800),
     #        torch.randint(high=len(symbols), size=(2,180)))
-    #o = gpt.infer(torch.randint(high=24, size=(2,60)))
-    #print(o.shape)
+    with torch.no_grad():
+        t = torch.randn(1,80,800).cuda()
+        start = time()
+        s = gpt.inference_beam_topk(t)
+        print(time()-start)
+
+        start = time()
+        o = gpt.inference_beam_topk(t, fn='inference_beam_opt')
+        print(time()-start)
diff --git a/codes/scripts/audio/asr_eval.py b/codes/scripts/audio/asr_eval.py
index 9a74ab70..bcb6b917 100644
--- a/codes/scripts/audio/asr_eval.py
+++ b/codes/scripts/audio/asr_eval.py
@@ -41,7 +41,7 @@ if __name__ == "__main__":
     torch.backends.cudnn.benchmark = True
     want_metrics = False
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_deepspeech_libri.yml')
+    parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_gpt_asr_mass.yml')
     opt = option.parse(parser.parse_args().opt, is_train=False)
     opt = option.dict_to_nonedict(opt)
     utils.util.loaded_options = opt
@@ -71,8 +71,8 @@ if __name__ == "__main__":
 
     tq = tqdm(test_loader)
     for data in tq:
-        #if data['clips'].shape[-1] > opt['networks']['asr_gen']['kwargs']['max_mel_frames']*255:
-        #    continue
+        if data['clip'].shape[-1] > opt['networks']['asr_gen']['kwargs']['max_mel_frames']*255:
+            continue
         pred = forward_pass(model, data, dataset_dir, opt, batch)
         pred = pred.replace('_', '')
         output.write(f'{pred}\t{os.path.basename(data["path"][0])}\n')
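
Note for reviewers: per the subject line, the x-transformers-side code behind infer_last_two is not part of this patch. As a reading aid only, below is a minimal, self-contained sketch of the key/value-caching idea that inference_beam_opt appears to lean on; the (i[0], i[1]) tuples above suggest the intermediates are per-layer key/value tensors. Every name in the sketch (full_attention, cached_attention_step, kv_cache) is illustrative, not the repository's API, and the sketch advances a single trailing position per step, whereas infer_last_two, going by its name, presumably recomputes the last two.

# Sketch only (assumed mechanics, not the repo's API): incremental decoding
# with a (key, value) cache.
import torch
import torch.nn.functional as F


def full_attention(x, wq, wk, wv):
    # Single-head causal self-attention, recomputed from scratch over the
    # whole sequence -- what the unoptimized beam search pays on every step.
    q, k, v = x @ wq, x @ wk, x @ wv
    scores = q @ k.transpose(-1, -2) / k.shape[-1] ** 0.5
    causal = torch.triu(torch.ones(x.shape[1], x.shape[1], dtype=torch.bool), diagonal=1)
    return F.softmax(scores.masked_fill(causal, float('-inf')), dim=-1) @ v


def cached_attention_step(x_new, kv_cache, wq, wk, wv):
    # Attention for only the newly appended position: the prefix's keys and
    # values come from the cache, so a step costs O(seq) instead of O(seq^2).
    # (Valid as-is for one new token; several new tokens would need a mask.)
    q, k_new, v_new = x_new @ wq, x_new @ wk, x_new @ wv
    k = torch.cat([kv_cache[0], k_new], dim=1)
    v = torch.cat([kv_cache[1], v_new], dim=1)
    scores = q @ k.transpose(-1, -2) / k.shape[-1] ** 0.5
    return F.softmax(scores, dim=-1) @ v, (k, v)


if __name__ == '__main__':
    torch.manual_seed(0)
    d = 16
    wq, wk, wv = torch.randn(3, d, d)
    x = torch.randn(1, 5, d)
    # Prime the cache on a 4-token prefix, then decode token 5 incrementally.
    kv_cache = (x[:, :4] @ wk, x[:, :4] @ wv)
    step_out, kv_cache = cached_attention_step(x[:, 4:], kv_cache, wq, wk, wv)
    full_out = full_attention(x, wq, wk, wv)
    print(torch.allclose(step_out, full_out[:, 4:], atol=1e-5))  # -> True

The payoff grows with the prefix: inference_beam re-runs attention over the full mel+text sequence (up to max_mel_frames + max_symbols_per_phrase positions) for every emitted token, while a cached step only computes queries for the new position(s).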