Improve asr_eval

James Betker 2021-12-28 11:45:15 -07:00
parent 312f631c5b
commit a5b4bee719
2 changed files with 31 additions and 35 deletions


@@ -297,17 +297,11 @@ class GptAsrHf2(nn.Module):
         self.inference_model.store_mel_emb(mel_emb)
         # "fake_inputs" are stand-ins for the MEL frames, which will be injected with the prep_inputs function above.
-        if cond_text is None:
-            fake_inputs = torch.full((mel_emb.shape[0],mel_emb.shape[1]+1,), fill_value=1, dtype=torch.long, device=mel_inputs.device)
-            fake_inputs[:,-1] = self.start_token
-        else:
-            cond_used = 10
-            fake_inputs = torch.full((mel_emb.shape[0],mel_emb.shape[1]+1+cond_used,), fill_value=1, dtype=torch.long, device=mel_inputs.device)
-            fake_inputs[:,-1-cond_used] = self.start_token
-            fake_inputs[:, -cond_used:] = cond_text[:, :cond_used]
+        fake_inputs = torch.full((mel_emb.shape[0],mel_emb.shape[1]+1,), fill_value=1, dtype=torch.long, device=mel_inputs.device)
+        fake_inputs[:,-1] = self.start_token
         gen = self.inference_model.generate(fake_inputs, do_sample=do_sample, bos_token_id=self.start_token, pad_token_id=0, eos_token_id=0,
                                             max_length=self.max_symbols_per_phrase+mel_emb.shape[1], temperature=temperature, num_beams=num_beams, use_cache=True)
-        return gen[:, mel_emb.shape[1]:]
+        return gen[:, mel_emb.shape[1]+1:]

 @register_model
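The slice change on the return line is the substantive fix here: generate() output still contains the placeholder positions for the MEL frames plus the start token, so the old slice leaked the start token into the returned text codes. A minimal standalone sketch of the off-by-one (batch size, lengths, and token values are hypothetical, not from the model):

import torch

mel_len, start_token = 4, 255   # hypothetical sizes/values for illustration
gen = torch.tensor([[1, 1, 1, 1, start_token, 12, 7, 9, 0]])

old = gen[:, mel_len:]      # tensor([[255, 12, 7, 9, 0]]) -- start token leaks through
new = gen[:, mel_len + 1:]  # tensor([[12, 7, 9, 0]])      -- text codes only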


@@ -18,19 +18,25 @@ import numpy as np
 from scipy.io import wavfile

-def forward_pass(model, data, output_dir, opt, b):
+def forward_pass(model, data, output_dir, opt, macro_b, dataset):
     with torch.no_grad():
         model.feed_data(data, 0)
         model.test()

-    if 'real_text' in opt['eval'].keys():
-        real = data[opt['eval']['real_text']][0]
-        print(f'{b} Real text: "{real}"')
-
     gt_key = opt['eval']['gen_text']
     txts = []
     for b in range(model.eval_state[gt_key][0].shape[0]):
-        txts.append(sequence_to_text(model.eval_state[opt['eval']['gen_text']][0][b]))
+        if 'real_text' in opt['eval'].keys():
+            real = data[opt['eval']['real_text']][b]
+            print(f'{macro_b} {b} Real text: "{real}"')
+
+        codes = model.eval_state[opt['eval']['gen_text']][0][b].cpu()
+        if hasattr(dataset, 'tokenizer'):
+            text = dataset.tokenizer.decode(codes.numpy())
+            text = text.replace(' $$$', '')
+            txts.append(text)
+        else:
+            txts.append(sequence_to_text(codes))
     return txts
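The per-item loop now prefers the dataset's own tokenizer when it has one, stripping the ' $$$' stop marker from the decoded string, and only falls back to sequence_to_text otherwise. A standalone sketch of that dispatch (the stub classes and canned output are hypothetical; sequence_to_text stands in for the repo's character-level decoder):

class StubTokenizer:                  # hypothetical stand-in for a real tokenizer
    def decode(self, codes):
        return 'hello world $$$'      # canned output for illustration

class StubDataset:                    # hypothetical stand-in for a real dataset
    tokenizer = StubTokenizer()

def decode_codes(codes, dataset, fallback):
    # Prefer a dataset-supplied tokenizer; strip the ' $$$' stop marker it emits.
    if hasattr(dataset, 'tokenizer'):
        return dataset.tokenizer.decode(codes).replace(' $$$', '')
    return fallback(codes)

print(decode_codes([1, 2, 3], StubDataset(), fallback=str))  # -> 'hello world'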
@@ -57,26 +63,22 @@ if __name__ == "__main__":
     logger = logging.getLogger('base')
     logger.info(option.dict2str(opt))

-    test_loaders = []
-    for phase, dataset_opt in sorted(opt['datasets'].items()):
-        test_set, collate_fn = create_dataset(dataset_opt, return_collate=True)
-        test_loader = create_dataloader(test_set, dataset_opt, collate_fn=collate_fn)
-        logger.info('Number of test texts in [{:s}]: {:d}'.format(dataset_opt['name'], len(test_set)))
-        test_loaders.append(test_loader)
+    dataset_opt = opt['datasets']['val']
+    test_set, collate_fn = create_dataset(dataset_opt, return_collate=True)
+    test_loader = create_dataloader(test_set, dataset_opt, collate_fn=collate_fn)
+    logger.info('Number of test texts in [{:s}]: {:d}'.format(dataset_opt['name'], len(test_set)))

     model = ExtensibleTrainer(opt)
     batch = 0
     output = open('results.tsv', 'w')

-    for test_loader in test_loaders:
-        dataset_dir = opt['path']['results_root']
-        util.mkdir(dataset_dir)
-        tq = tqdm(test_loader)
-        for data in tq:
-            #if data['clip'].shape[-1] > opt['networks']['asr_gen']['kwargs']['max_mel_frames']*255:
-            #    continue
-            preds = forward_pass(model, data, dataset_dir, opt, batch)
-            for b, pred in enumerate(preds):
-                pred = pred.replace('_', '')
-                output.write(f'{pred}\t{os.path.basename(data["filenames"][b])}\n')
+    dataset_dir = opt['path']['results_root']
+    util.mkdir(dataset_dir)
+    for data in tqdm(test_loader):
+        #if data['clip'].shape[-1] > opt['networks']['asr_gen']['kwargs']['max_mel_frames']*255:
+        #    continue
+        preds = forward_pass(model, data, dataset_dir, opt, batch, test_set)
+        for b, pred in enumerate(preds):
+            pred = pred.replace('_', '')
+            output.write(f'{pred}\t{os.path.basename(data["filenames"][b])}\n')
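The eval loop writes one tab-separated line per clip: prediction first, then the source file's basename. A minimal reader for that format, the exact inverse of the write above (assumes only that predictions contain no tabs, which holds for the decoded text as written):

with open('results.tsv') as f:
    for line in f:
        pred, fname = line.rstrip('\n').split('\t')
        print(f'{fname}: {pred}')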