Improve asr_eval

This commit is contained in:
James Betker 2021-12-28 11:45:15 -07:00
parent 312f631c5b
commit a5b4bee719
2 changed files with 31 additions and 35 deletions

View File

@ -297,17 +297,11 @@ class GptAsrHf2(nn.Module):
self.inference_model.store_mel_emb(mel_emb)
# "fake_inputs" are stand-ins for the MEL frames, which will be injected with the prep_inputs function above.
if cond_text is None:
fake_inputs = torch.full((mel_emb.shape[0],mel_emb.shape[1]+1,), fill_value=1, dtype=torch.long, device=mel_inputs.device)
fake_inputs[:,-1] = self.start_token
else:
cond_used = 10
fake_inputs = torch.full((mel_emb.shape[0],mel_emb.shape[1]+1+cond_used,), fill_value=1, dtype=torch.long, device=mel_inputs.device)
fake_inputs[:,-1-cond_used] = self.start_token
fake_inputs[:, -cond_used:] = cond_text[:, :cond_used]
fake_inputs = torch.full((mel_emb.shape[0],mel_emb.shape[1]+1,), fill_value=1, dtype=torch.long, device=mel_inputs.device)
fake_inputs[:,-1] = self.start_token
gen = self.inference_model.generate(fake_inputs, do_sample=do_sample, bos_token_id=self.start_token, pad_token_id=0, eos_token_id=0,
max_length=self.max_symbols_per_phrase+mel_emb.shape[1], temperature=temperature, num_beams=num_beams, use_cache=True)
return gen[:, mel_emb.shape[1]:]
return gen[:, mel_emb.shape[1]+1:]
@register_model

View File

@ -18,19 +18,25 @@ import numpy as np
from scipy.io import wavfile
def forward_pass(model, data, output_dir, opt, b):
def forward_pass(model, data, output_dir, opt, macro_b, dataset):
with torch.no_grad():
model.feed_data(data, 0)
model.test()
if 'real_text' in opt['eval'].keys():
real = data[opt['eval']['real_text']][0]
print(f'{b} Real text: "{real}"')
gt_key = opt['eval']['gen_text']
txts = []
for b in range(model.eval_state[gt_key][0].shape[0]):
txts.append(sequence_to_text(model.eval_state[opt['eval']['gen_text']][0][b]))
if 'real_text' in opt['eval'].keys():
real = data[opt['eval']['real_text']][b]
print(f'{macro_b} {b} Real text: "{real}"')
codes = model.eval_state[opt['eval']['gen_text']][0][b].cpu()
if hasattr(dataset, 'tokenizer'):
text = dataset.tokenizer.decode(codes.numpy())
text = text.replace(' $$$', '')
txts.append(text)
else:
txts.append(sequence_to_text(codes))
return txts
@ -57,30 +63,26 @@ if __name__ == "__main__":
logger = logging.getLogger('base')
logger.info(option.dict2str(opt))
test_loaders = []
for phase, dataset_opt in sorted(opt['datasets'].items()):
test_set, collate_fn = create_dataset(dataset_opt, return_collate=True)
test_loader = create_dataloader(test_set, dataset_opt, collate_fn=collate_fn)
logger.info('Number of test texts in [{:s}]: {:d}'.format(dataset_opt['name'], len(test_set)))
test_loaders.append(test_loader)
dataset_opt = opt['datasets']['val']
test_set, collate_fn = create_dataset(dataset_opt, return_collate=True)
test_loader = create_dataloader(test_set, dataset_opt, collate_fn=collate_fn)
logger.info('Number of test texts in [{:s}]: {:d}'.format(dataset_opt['name'], len(test_set)))
model = ExtensibleTrainer(opt)
batch = 0
output = open('results.tsv', 'w')
for test_loader in test_loaders:
dataset_dir = opt['path']['results_root']
util.mkdir(dataset_dir)
dataset_dir = opt['path']['results_root']
util.mkdir(dataset_dir)
tq = tqdm(test_loader)
for data in tq:
#if data['clip'].shape[-1] > opt['networks']['asr_gen']['kwargs']['max_mel_frames']*255:
# continue
preds = forward_pass(model, data, dataset_dir, opt, batch)
for b, pred in enumerate(preds):
pred = pred.replace('_', '')
output.write(f'{pred}\t{os.path.basename(data["filenames"][b])}\n')
print(pred)
batch += 1
output.flush()
for data in tqdm(test_loader):
#if data['clip'].shape[-1] > opt['networks']['asr_gen']['kwargs']['max_mel_frames']*255:
# continue
preds = forward_pass(model, data, dataset_dir, opt, batch, test_set)
for b, pred in enumerate(preds):
pred = pred.replace('_', '')
output.write(f'{pred}\t{os.path.basename(data["filenames"][b])}\n')
print(pred)
batch += 1
output.flush()