Working gpt_asr_hf inference - and it's a beast!

This commit is contained in:
James Betker 2021-11-06 21:47:15 -06:00
parent 596a62fe01
commit 756b4dad09
3 changed files with 24 additions and 14 deletions

View File

@ -1,3 +1,5 @@
from time import time
import torch
import torch.nn as nn
import torch.nn.functional as F
@ -145,7 +147,7 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
text_emb = self.transformer.get_input_embeddings()(text_inputs)
text_emb = text_emb + self.text_pos_embedding(torch.arange(text_emb.shape[1], device=text_emb.device))
if self.cached_mel_emb.shape[0] != text_emb.shape[0]:
mel_emb = self.cached_mel_emb.repeat(text_emb.shape[0], 1, 1)
mel_emb = self.cached_mel_emb.repeat_interleave(text_emb.shape[0]//self.cached_mel_emb.shape[0], 0)
else:
mel_emb = self.cached_mel_emb
emb = torch.cat([mel_emb, text_emb], dim=1)
@ -264,17 +266,18 @@ class GptAsrHf(nn.Module):
# "fake_inputs" are stand-ins for the MEL frames, which will be injected with the prep_inputs function above.
if cond_text is None:
fake_inputs = torch.full((1,self.max_mel_frames+1,), fill_value=1, dtype=torch.long, device=mel_inputs.device)
fake_inputs = torch.full((mel_inputs.shape[0],self.max_mel_frames+1,), fill_value=1, dtype=torch.long, device=mel_inputs.device)
fake_inputs[:,-1] = self.NUMBER_SYMBOLS
else:
cond_used = 10
fake_inputs = torch.full((1,self.max_mel_frames+1+cond_used,), fill_value=1, dtype=torch.long, device=mel_inputs.device)
fake_inputs = torch.full((mel_inputs.shape[0],self.max_mel_frames+1+cond_used,), fill_value=1, dtype=torch.long, device=mel_inputs.device)
fake_inputs[:,-1-cond_used] = self.NUMBER_SYMBOLS
fake_inputs[:, -cond_used:] = cond_text[:, :cond_used]
gen = self.inference_model.generate(fake_inputs, do_sample=do_sample, bos_token_id=self.NUMBER_SYMBOLS, pad_token_id=0, eos_token_id=0,
max_length=self.max_symbols_per_phrase+self.max_mel_frames, temperature=temperature, num_beams=num_beams, use_cache=False)
max_length=self.max_symbols_per_phrase+self.max_mel_frames, temperature=temperature, num_beams=num_beams, use_cache=True)
return gen[:, self.max_mel_frames:]
@register_model
def register_gpt_asr_hf(opt_net, opt):
return GptAsrHf(**opt_get(opt_net, ['kwargs'], {}))
@ -296,8 +299,11 @@ def distill():
if __name__ == '__main__':
gpt = GptAsrHf(max_symbols_per_phrase=100, max_mel_frames=200, layers=6, model_dim=256, heads=2)
l = gpt(torch.randn(2,80,800), torch.randint(high=len(symbols), size=(2,100)))
gpt = GptAsrHf(max_symbols_per_phrase=250, max_mel_frames=1400, layers=16, model_dim=512, heads=8)
#l = gpt(torch.randn(2,80,800), torch.randint(high=len(symbols), size=(2,100)))
start = time()
gpt.inference(torch.randn(1,80,350), num_beams=1)
print(f"Elapsed: {time()-start}")
'''
with torch.no_grad():

View File

@ -27,8 +27,11 @@ def forward_pass(model, data, output_dir, opt, b):
real = data[opt['eval']['real_text']][0]
print(f'{b} Real text: "{real}"')
pred_seq = model.eval_state[opt['eval']['gen_text']][0][0] # Grab first sequence, which should represent the most likely sequence.
return sequence_to_text(pred_seq)
gt_key = opt['eval']['gen_text']
txts = []
for b in range(model.eval_state[gt_key][0].shape[0]):
txts.append(sequence_to_text(model.eval_state[opt['eval']['gen_text']][0][b]))
return txts
if __name__ == "__main__":
@ -73,10 +76,11 @@ if __name__ == "__main__":
for data in tq:
#if data['clip'].shape[-1] > opt['networks']['asr_gen']['kwargs']['max_mel_frames']*255:
# continue
pred = forward_pass(model, data, dataset_dir, opt, batch)
pred = pred.replace('_', '')
output.write(f'{pred}\t{os.path.basename(data["filenames"][0])}\n')
print(pred)
preds = forward_pass(model, data, dataset_dir, opt, batch)
for b, pred in enumerate(preds):
pred = pred.replace('_', '')
output.write(f'{pred}\t{os.path.basename(data["filenames"][b])}\n')
print(pred)
batch += 1
output.flush()
batch += 1

View File

@ -57,7 +57,7 @@ class WordErrorRate:
if __name__ == '__main__':
inference_tsv = '\\\\192.168.5.3\\rtx3080_drv\\dlas\\codes\\eval_libritts_for_gpt_asr_results_WER=2.6615.tsv'
inference_tsv = 'D:\\dlas\\codes\\31000ema_8_beam.tsv'
libri_base = 'Z:\\libritts\\test-clean'
wer = WordErrorRate()