diff --git a/codes/models/gpt_voice/gpt_asr_hf.py b/codes/models/gpt_voice/gpt_asr_hf.py index 1ba6217b..9a486e7b 100644 --- a/codes/models/gpt_voice/gpt_asr_hf.py +++ b/codes/models/gpt_voice/gpt_asr_hf.py @@ -1,3 +1,5 @@ +from time import time + import torch import torch.nn as nn import torch.nn.functional as F @@ -145,7 +147,7 @@ class GPT2InferenceModel(GPT2PreTrainedModel): text_emb = self.transformer.get_input_embeddings()(text_inputs) text_emb = text_emb + self.text_pos_embedding(torch.arange(text_emb.shape[1], device=text_emb.device)) if self.cached_mel_emb.shape[0] != text_emb.shape[0]: - mel_emb = self.cached_mel_emb.repeat(text_emb.shape[0], 1, 1) + mel_emb = self.cached_mel_emb.repeat_interleave(text_emb.shape[0]//self.cached_mel_emb.shape[0], 0) else: mel_emb = self.cached_mel_emb emb = torch.cat([mel_emb, text_emb], dim=1) @@ -264,17 +266,18 @@ class GptAsrHf(nn.Module): # "fake_inputs" are stand-ins for the MEL frames, which will be injected with the prep_inputs function above. if cond_text is None: - fake_inputs = torch.full((1,self.max_mel_frames+1,), fill_value=1, dtype=torch.long, device=mel_inputs.device) + fake_inputs = torch.full((mel_inputs.shape[0],self.max_mel_frames+1,), fill_value=1, dtype=torch.long, device=mel_inputs.device) fake_inputs[:,-1] = self.NUMBER_SYMBOLS else: cond_used = 10 - fake_inputs = torch.full((1,self.max_mel_frames+1+cond_used,), fill_value=1, dtype=torch.long, device=mel_inputs.device) + fake_inputs = torch.full((mel_inputs.shape[0],self.max_mel_frames+1+cond_used,), fill_value=1, dtype=torch.long, device=mel_inputs.device) fake_inputs[:,-1-cond_used] = self.NUMBER_SYMBOLS fake_inputs[:, -cond_used:] = cond_text[:, :cond_used] gen = self.inference_model.generate(fake_inputs, do_sample=do_sample, bos_token_id=self.NUMBER_SYMBOLS, pad_token_id=0, eos_token_id=0, - max_length=self.max_symbols_per_phrase+self.max_mel_frames, temperature=temperature, num_beams=num_beams, use_cache=False) + max_length=self.max_symbols_per_phrase+self.max_mel_frames, temperature=temperature, num_beams=num_beams, use_cache=True) return gen[:, self.max_mel_frames:] + @register_model def register_gpt_asr_hf(opt_net, opt): return GptAsrHf(**opt_get(opt_net, ['kwargs'], {})) @@ -296,8 +299,11 @@ def distill(): if __name__ == '__main__': - gpt = GptAsrHf(max_symbols_per_phrase=100, max_mel_frames=200, layers=6, model_dim=256, heads=2) - l = gpt(torch.randn(2,80,800), torch.randint(high=len(symbols), size=(2,100))) + gpt = GptAsrHf(max_symbols_per_phrase=250, max_mel_frames=1400, layers=16, model_dim=512, heads=8) + #l = gpt(torch.randn(2,80,800), torch.randint(high=len(symbols), size=(2,100))) + start = time() + gpt.inference(torch.randn(1,80,350), num_beams=1) + print(f"Elapsed: {time()-start}") ''' with torch.no_grad(): diff --git a/codes/scripts/audio/asr_eval.py b/codes/scripts/audio/asr_eval.py index e459705b..96605a17 100644 --- a/codes/scripts/audio/asr_eval.py +++ b/codes/scripts/audio/asr_eval.py @@ -27,8 +27,11 @@ def forward_pass(model, data, output_dir, opt, b): real = data[opt['eval']['real_text']][0] print(f'{b} Real text: "{real}"') - pred_seq = model.eval_state[opt['eval']['gen_text']][0][0] # Grab first sequence, which should represent the most likely sequence. - return sequence_to_text(pred_seq) + gt_key = opt['eval']['gen_text'] + txts = [] + for b in range(model.eval_state[gt_key][0].shape[0]): + txts.append(sequence_to_text(model.eval_state[opt['eval']['gen_text']][0][b])) + return txts if __name__ == "__main__": @@ -73,10 +76,11 @@ if __name__ == "__main__": for data in tq: #if data['clip'].shape[-1] > opt['networks']['asr_gen']['kwargs']['max_mel_frames']*255: # continue - pred = forward_pass(model, data, dataset_dir, opt, batch) - pred = pred.replace('_', '') - output.write(f'{pred}\t{os.path.basename(data["filenames"][0])}\n') - print(pred) + preds = forward_pass(model, data, dataset_dir, opt, batch) + for b, pred in enumerate(preds): + pred = pred.replace('_', '') + output.write(f'{pred}\t{os.path.basename(data["filenames"][b])}\n') + print(pred) + batch += 1 output.flush() - batch += 1 diff --git a/codes/scripts/audio/word_error_rate.py b/codes/scripts/audio/word_error_rate.py index bcc257a8..5f663230 100644 --- a/codes/scripts/audio/word_error_rate.py +++ b/codes/scripts/audio/word_error_rate.py @@ -57,7 +57,7 @@ class WordErrorRate: if __name__ == '__main__': - inference_tsv = '\\\\192.168.5.3\\rtx3080_drv\\dlas\\codes\\eval_libritts_for_gpt_asr_results_WER=2.6615.tsv' + inference_tsv = 'D:\\dlas\\codes\\31000ema_8_beam.tsv' libri_base = 'Z:\\libritts\\test-clean' wer = WordErrorRate()