forked from mrq/DL-Art-School

Working gpt_asr_hf inference - and it's a beast!

parent 596a62fe01
commit 756b4dad09
@@ -1,3 +1,5 @@
+from time import time
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -145,7 +147,7 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
         text_emb = self.transformer.get_input_embeddings()(text_inputs)
         text_emb = text_emb + self.text_pos_embedding(torch.arange(text_emb.shape[1], device=text_emb.device))
         if self.cached_mel_emb.shape[0] != text_emb.shape[0]:
-            mel_emb = self.cached_mel_emb.repeat(text_emb.shape[0], 1, 1)
+            mel_emb = self.cached_mel_emb.repeat_interleave(text_emb.shape[0]//self.cached_mel_emb.shape[0], 0)
         else:
             mel_emb = self.cached_mel_emb
         emb = torch.cat([mel_emb, text_emb], dim=1)
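Note: the repeat → repeat_interleave switch is what makes batched/beam decoding line up. Hugging Face generate() expands a batch of N inputs to N*num_beams rows grouped per sample, so each cached MEL embedding has to be duplicated consecutively rather than tiled. A minimal standalone sketch of the difference (toy tensors, not the model's real shapes):

    import torch

    # Two cached conditioning embeddings, one per audio clip: shape [2, 1, 3].
    cached = torch.tensor([[[1., 1., 1.]], [[2., 2., 2.]]])
    num_beams = 2

    tiled = cached.repeat(num_beams, 1, 1)            # rows come out as [a, b, a, b]
    grouped = cached.repeat_interleave(num_beams, 0)  # rows come out as [a, a, b, b]

    print(tiled[:, 0, 0])    # tensor([1., 2., 1., 2.]) -- beams would see the wrong clip
    print(grouped[:, 0, 0])  # tensor([1., 1., 2., 2.]) -- matches generate()'s beam grouping
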
@@ -264,17 +266,18 @@ class GptAsrHf(nn.Module):

+        # "fake_inputs" are stand-ins for the MEL frames, which will be injected with the prep_inputs function above.
         if cond_text is None:
-            fake_inputs = torch.full((1,self.max_mel_frames+1,), fill_value=1, dtype=torch.long, device=mel_inputs.device)
+            fake_inputs = torch.full((mel_inputs.shape[0],self.max_mel_frames+1,), fill_value=1, dtype=torch.long, device=mel_inputs.device)
             fake_inputs[:,-1] = self.NUMBER_SYMBOLS
         else:
             cond_used = 10
-            fake_inputs = torch.full((1,self.max_mel_frames+1+cond_used,), fill_value=1, dtype=torch.long, device=mel_inputs.device)
+            fake_inputs = torch.full((mel_inputs.shape[0],self.max_mel_frames+1+cond_used,), fill_value=1, dtype=torch.long, device=mel_inputs.device)
             fake_inputs[:,-1-cond_used] = self.NUMBER_SYMBOLS
             fake_inputs[:, -cond_used:] = cond_text[:, :cond_used]
         gen = self.inference_model.generate(fake_inputs, do_sample=do_sample, bos_token_id=self.NUMBER_SYMBOLS, pad_token_id=0, eos_token_id=0,
-                                            max_length=self.max_symbols_per_phrase+self.max_mel_frames, temperature=temperature, num_beams=num_beams, use_cache=False)
+                                            max_length=self.max_symbols_per_phrase+self.max_mel_frames, temperature=temperature, num_beams=num_beams, use_cache=True)
         return gen[:, self.max_mel_frames:]


 @register_model
 def register_gpt_asr_hf(opt_net, opt):
     return GptAsrHf(**opt_get(opt_net, ['kwargs'], {}))
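Note: two changes make batched inference work here. The placeholder tensor is sized by mel_inputs.shape[0] instead of a hardcoded 1, and use_cache=True lets generate() reuse the transformer's key/value cache between decoding steps instead of recomputing the whole prefix for every token, which is presumably where the speedup in this commit comes from. A rough, runnable sketch of the placeholder layout with made-up sizes (the NUMBER_SYMBOLS value here is hypothetical):

    import torch

    batch, max_mel_frames = 2, 4
    NUMBER_SYMBOLS = 148  # hypothetical symbol-table size, used as the start-of-text marker

    # One dummy slot per MEL frame, plus a trailing start-of-text marker.
    fake_inputs = torch.full((batch, max_mel_frames + 1), fill_value=1, dtype=torch.long)
    fake_inputs[:, -1] = NUMBER_SYMBOLS
    print(fake_inputs)
    # tensor([[  1,   1,   1,   1, 148],
    #         [  1,   1,   1,   1, 148]])
    # The forward pass swaps the leading 1s for cached MEL embeddings, so
    # generate() only ever samples text positions after the marker.
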
@@ -296,8 +299,11 @@ def distill():


 if __name__ == '__main__':
-    gpt = GptAsrHf(max_symbols_per_phrase=100, max_mel_frames=200, layers=6, model_dim=256, heads=2)
-    l = gpt(torch.randn(2,80,800), torch.randint(high=len(symbols), size=(2,100)))
+    gpt = GptAsrHf(max_symbols_per_phrase=250, max_mel_frames=1400, layers=16, model_dim=512, heads=8)
+    #l = gpt(torch.randn(2,80,800), torch.randint(high=len(symbols), size=(2,100)))
+    start = time()
+    gpt.inference(torch.randn(1,80,350), num_beams=1)
+    print(f"Elapsed: {time()-start}")

     '''
     with torch.no_grad():
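Note: the smoke test above times one beam-search decode over random MEL input. A hedged sketch of the same harness factored into a helper, run under torch.no_grad() to skip autograd bookkeeping at decode time (gpt stands for the GptAsrHf instance constructed above; the helper name is made up):

    import torch
    from time import time

    def time_inference(gpt, mel, num_beams=1):
        # Decode once and report wall-clock seconds; no gradients are needed at inference.
        with torch.no_grad():
            start = time()
            gpt.inference(mel, num_beams=num_beams)
        return time() - start
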
@@ -27,8 +27,11 @@ def forward_pass(model, data, output_dir, opt, b):
     real = data[opt['eval']['real_text']][0]
     print(f'{b} Real text: "{real}"')

-    pred_seq = model.eval_state[opt['eval']['gen_text']][0][0]  # Grab first sequence, which should represent the most likely sequence.
-    return sequence_to_text(pred_seq)
+    gt_key = opt['eval']['gen_text']
+    txts = []
+    for b in range(model.eval_state[gt_key][0].shape[0]):
+        txts.append(sequence_to_text(model.eval_state[opt['eval']['gen_text']][0][b]))
+    return txts


 if __name__ == "__main__":
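Note: forward_pass now returns one decoded transcript per element of the eval batch instead of only the first sequence. The shape of that pattern, with a stand-in decoder (the real sequence_to_text maps symbol ids through the project's symbol table):

    import torch

    def sequence_to_text(seq):
        # Stand-in decoder: maps each id to a letter, purely for illustration.
        return ''.join(chr(97 + int(i) % 26) for i in seq)

    sequences = torch.randint(high=26, size=(4, 10))  # [batch, tokens]
    txts = [sequence_to_text(sequences[b]) for b in range(sequences.shape[0])]
    assert len(txts) == sequences.shape[0]  # one transcript per clip in the batch
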
@@ -73,10 +76,11 @@ if __name__ == "__main__":
     for data in tq:
         #if data['clip'].shape[-1] > opt['networks']['asr_gen']['kwargs']['max_mel_frames']*255:
         #    continue
-        pred = forward_pass(model, data, dataset_dir, opt, batch)
-        pred = pred.replace('_', '')
-        output.write(f'{pred}\t{os.path.basename(data["filenames"][0])}\n')
-        print(pred)
+        preds = forward_pass(model, data, dataset_dir, opt, batch)
+        for b, pred in enumerate(preds):
+            pred = pred.replace('_', '')
+            output.write(f'{pred}\t{os.path.basename(data["filenames"][b])}\n')
+            print(pred)
+            batch += 1
         output.flush()
-        batch += 1
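Note: the writer loop mirrors the batched forward_pass: one TSV row per batch element, keyed by that element's filename, and one flush per batch rather than per line. A small standalone illustration with hypothetical data (filenames parallel to predictions, as in the loader's batch dict):

    import os

    preds = ['hello world', 'second clip']
    filenames = ['/data/clips/a.wav', '/data/clips/b.wav']

    with open('results.tsv', 'w', encoding='utf-8') as output:
        for b, pred in enumerate(preds):
            pred = pred.replace('_', '')  # strip underscore/padding symbols from the decoded text
            output.write(f'{pred}\t{os.path.basename(filenames[b])}\n')
        output.flush()  # keeps partial results on disk during long eval runs
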
@@ -57,7 +57,7 @@ class WordErrorRate:


 if __name__ == '__main__':
-    inference_tsv = '\\\\192.168.5.3\\rtx3080_drv\\dlas\\codes\\eval_libritts_for_gpt_asr_results_WER=2.6615.tsv'
+    inference_tsv = 'D:\\dlas\\codes\\31000ema_8_beam.tsv'
     libri_base = 'Z:\\libritts\\test-clean'

     wer = WordErrorRate()