forked from mrq/DL-Art-School
Working gpt_asr_hf inference - and it's a beast!
This commit is contained in:
parent
596a62fe01
commit
756b4dad09
|
@ -1,3 +1,5 @@
|
||||||
|
from time import time
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
@ -145,7 +147,7 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
|
||||||
text_emb = self.transformer.get_input_embeddings()(text_inputs)
|
text_emb = self.transformer.get_input_embeddings()(text_inputs)
|
||||||
text_emb = text_emb + self.text_pos_embedding(torch.arange(text_emb.shape[1], device=text_emb.device))
|
text_emb = text_emb + self.text_pos_embedding(torch.arange(text_emb.shape[1], device=text_emb.device))
|
||||||
if self.cached_mel_emb.shape[0] != text_emb.shape[0]:
|
if self.cached_mel_emb.shape[0] != text_emb.shape[0]:
|
||||||
mel_emb = self.cached_mel_emb.repeat(text_emb.shape[0], 1, 1)
|
mel_emb = self.cached_mel_emb.repeat_interleave(text_emb.shape[0]//self.cached_mel_emb.shape[0], 0)
|
||||||
else:
|
else:
|
||||||
mel_emb = self.cached_mel_emb
|
mel_emb = self.cached_mel_emb
|
||||||
emb = torch.cat([mel_emb, text_emb], dim=1)
|
emb = torch.cat([mel_emb, text_emb], dim=1)
|
||||||
|
@ -264,17 +266,18 @@ class GptAsrHf(nn.Module):
|
||||||
|
|
||||||
# "fake_inputs" are stand-ins for the MEL frames, which will be injected with the prep_inputs function above.
|
# "fake_inputs" are stand-ins for the MEL frames, which will be injected with the prep_inputs function above.
|
||||||
if cond_text is None:
|
if cond_text is None:
|
||||||
fake_inputs = torch.full((1,self.max_mel_frames+1,), fill_value=1, dtype=torch.long, device=mel_inputs.device)
|
fake_inputs = torch.full((mel_inputs.shape[0],self.max_mel_frames+1,), fill_value=1, dtype=torch.long, device=mel_inputs.device)
|
||||||
fake_inputs[:,-1] = self.NUMBER_SYMBOLS
|
fake_inputs[:,-1] = self.NUMBER_SYMBOLS
|
||||||
else:
|
else:
|
||||||
cond_used = 10
|
cond_used = 10
|
||||||
fake_inputs = torch.full((1,self.max_mel_frames+1+cond_used,), fill_value=1, dtype=torch.long, device=mel_inputs.device)
|
fake_inputs = torch.full((mel_inputs.shape[0],self.max_mel_frames+1+cond_used,), fill_value=1, dtype=torch.long, device=mel_inputs.device)
|
||||||
fake_inputs[:,-1-cond_used] = self.NUMBER_SYMBOLS
|
fake_inputs[:,-1-cond_used] = self.NUMBER_SYMBOLS
|
||||||
fake_inputs[:, -cond_used:] = cond_text[:, :cond_used]
|
fake_inputs[:, -cond_used:] = cond_text[:, :cond_used]
|
||||||
gen = self.inference_model.generate(fake_inputs, do_sample=do_sample, bos_token_id=self.NUMBER_SYMBOLS, pad_token_id=0, eos_token_id=0,
|
gen = self.inference_model.generate(fake_inputs, do_sample=do_sample, bos_token_id=self.NUMBER_SYMBOLS, pad_token_id=0, eos_token_id=0,
|
||||||
max_length=self.max_symbols_per_phrase+self.max_mel_frames, temperature=temperature, num_beams=num_beams, use_cache=False)
|
max_length=self.max_symbols_per_phrase+self.max_mel_frames, temperature=temperature, num_beams=num_beams, use_cache=True)
|
||||||
return gen[:, self.max_mel_frames:]
|
return gen[:, self.max_mel_frames:]
|
||||||
|
|
||||||
|
|
||||||
@register_model
|
@register_model
|
||||||
def register_gpt_asr_hf(opt_net, opt):
|
def register_gpt_asr_hf(opt_net, opt):
|
||||||
return GptAsrHf(**opt_get(opt_net, ['kwargs'], {}))
|
return GptAsrHf(**opt_get(opt_net, ['kwargs'], {}))
|
||||||
|
@ -296,8 +299,11 @@ def distill():
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
gpt = GptAsrHf(max_symbols_per_phrase=100, max_mel_frames=200, layers=6, model_dim=256, heads=2)
|
gpt = GptAsrHf(max_symbols_per_phrase=250, max_mel_frames=1400, layers=16, model_dim=512, heads=8)
|
||||||
l = gpt(torch.randn(2,80,800), torch.randint(high=len(symbols), size=(2,100)))
|
#l = gpt(torch.randn(2,80,800), torch.randint(high=len(symbols), size=(2,100)))
|
||||||
|
start = time()
|
||||||
|
gpt.inference(torch.randn(1,80,350), num_beams=1)
|
||||||
|
print(f"Elapsed: {time()-start}")
|
||||||
|
|
||||||
'''
|
'''
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
|
|
|
@ -27,8 +27,11 @@ def forward_pass(model, data, output_dir, opt, b):
|
||||||
real = data[opt['eval']['real_text']][0]
|
real = data[opt['eval']['real_text']][0]
|
||||||
print(f'{b} Real text: "{real}"')
|
print(f'{b} Real text: "{real}"')
|
||||||
|
|
||||||
pred_seq = model.eval_state[opt['eval']['gen_text']][0][0] # Grab first sequence, which should represent the most likely sequence.
|
gt_key = opt['eval']['gen_text']
|
||||||
return sequence_to_text(pred_seq)
|
txts = []
|
||||||
|
for b in range(model.eval_state[gt_key][0].shape[0]):
|
||||||
|
txts.append(sequence_to_text(model.eval_state[opt['eval']['gen_text']][0][b]))
|
||||||
|
return txts
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -73,10 +76,11 @@ if __name__ == "__main__":
|
||||||
for data in tq:
|
for data in tq:
|
||||||
#if data['clip'].shape[-1] > opt['networks']['asr_gen']['kwargs']['max_mel_frames']*255:
|
#if data['clip'].shape[-1] > opt['networks']['asr_gen']['kwargs']['max_mel_frames']*255:
|
||||||
# continue
|
# continue
|
||||||
pred = forward_pass(model, data, dataset_dir, opt, batch)
|
preds = forward_pass(model, data, dataset_dir, opt, batch)
|
||||||
|
for b, pred in enumerate(preds):
|
||||||
pred = pred.replace('_', '')
|
pred = pred.replace('_', '')
|
||||||
output.write(f'{pred}\t{os.path.basename(data["filenames"][0])}\n')
|
output.write(f'{pred}\t{os.path.basename(data["filenames"][b])}\n')
|
||||||
print(pred)
|
print(pred)
|
||||||
output.flush()
|
|
||||||
batch += 1
|
batch += 1
|
||||||
|
output.flush()
|
||||||
|
|
||||||
|
|
|
@ -57,7 +57,7 @@ class WordErrorRate:
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
inference_tsv = '\\\\192.168.5.3\\rtx3080_drv\\dlas\\codes\\eval_libritts_for_gpt_asr_results_WER=2.6615.tsv'
|
inference_tsv = 'D:\\dlas\\codes\\31000ema_8_beam.tsv'
|
||||||
libri_base = 'Z:\\libritts\\test-clean'
|
libri_base = 'Z:\\libritts\\test-clean'
|
||||||
|
|
||||||
wer = WordErrorRate()
|
wer = WordErrorRate()
|
||||||
|
|
Loading…
Reference in New Issue
Block a user