Improve asr_eval
This commit is contained in:
parent
312f631c5b
commit
a5b4bee719
|
@ -297,17 +297,11 @@ class GptAsrHf2(nn.Module):
|
|||
self.inference_model.store_mel_emb(mel_emb)
|
||||
|
||||
# "fake_inputs" are stand-ins for the MEL frames, which will be injected with the prep_inputs function above.
|
||||
if cond_text is None:
|
||||
fake_inputs = torch.full((mel_emb.shape[0],mel_emb.shape[1]+1,), fill_value=1, dtype=torch.long, device=mel_inputs.device)
|
||||
fake_inputs[:,-1] = self.start_token
|
||||
else:
|
||||
cond_used = 10
|
||||
fake_inputs = torch.full((mel_emb.shape[0],mel_emb.shape[1]+1+cond_used,), fill_value=1, dtype=torch.long, device=mel_inputs.device)
|
||||
fake_inputs[:,-1-cond_used] = self.start_token
|
||||
fake_inputs[:, -cond_used:] = cond_text[:, :cond_used]
|
||||
fake_inputs = torch.full((mel_emb.shape[0],mel_emb.shape[1]+1,), fill_value=1, dtype=torch.long, device=mel_inputs.device)
|
||||
fake_inputs[:,-1] = self.start_token
|
||||
gen = self.inference_model.generate(fake_inputs, do_sample=do_sample, bos_token_id=self.start_token, pad_token_id=0, eos_token_id=0,
|
||||
max_length=self.max_symbols_per_phrase+mel_emb.shape[1], temperature=temperature, num_beams=num_beams, use_cache=True)
|
||||
return gen[:, mel_emb.shape[1]:]
|
||||
return gen[:, mel_emb.shape[1]+1:]
|
||||
|
||||
|
||||
@register_model
|
||||
|
|
|
@ -18,19 +18,25 @@ import numpy as np
|
|||
from scipy.io import wavfile
|
||||
|
||||
|
||||
def forward_pass(model, data, output_dir, opt, b):
|
||||
def forward_pass(model, data, output_dir, opt, macro_b, dataset):
|
||||
with torch.no_grad():
|
||||
model.feed_data(data, 0)
|
||||
model.test()
|
||||
|
||||
if 'real_text' in opt['eval'].keys():
|
||||
real = data[opt['eval']['real_text']][0]
|
||||
print(f'{b} Real text: "{real}"')
|
||||
|
||||
gt_key = opt['eval']['gen_text']
|
||||
txts = []
|
||||
for b in range(model.eval_state[gt_key][0].shape[0]):
|
||||
txts.append(sequence_to_text(model.eval_state[opt['eval']['gen_text']][0][b]))
|
||||
if 'real_text' in opt['eval'].keys():
|
||||
real = data[opt['eval']['real_text']][b]
|
||||
print(f'{macro_b} {b} Real text: "{real}"')
|
||||
|
||||
codes = model.eval_state[opt['eval']['gen_text']][0][b].cpu()
|
||||
if hasattr(dataset, 'tokenizer'):
|
||||
text = dataset.tokenizer.decode(codes.numpy())
|
||||
text = text.replace(' $$$', '')
|
||||
txts.append(text)
|
||||
else:
|
||||
txts.append(sequence_to_text(codes))
|
||||
return txts
|
||||
|
||||
|
||||
|
@ -57,30 +63,26 @@ if __name__ == "__main__":
|
|||
logger = logging.getLogger('base')
|
||||
logger.info(option.dict2str(opt))
|
||||
|
||||
test_loaders = []
|
||||
for phase, dataset_opt in sorted(opt['datasets'].items()):
|
||||
test_set, collate_fn = create_dataset(dataset_opt, return_collate=True)
|
||||
test_loader = create_dataloader(test_set, dataset_opt, collate_fn=collate_fn)
|
||||
logger.info('Number of test texts in [{:s}]: {:d}'.format(dataset_opt['name'], len(test_set)))
|
||||
test_loaders.append(test_loader)
|
||||
dataset_opt = opt['datasets']['val']
|
||||
test_set, collate_fn = create_dataset(dataset_opt, return_collate=True)
|
||||
test_loader = create_dataloader(test_set, dataset_opt, collate_fn=collate_fn)
|
||||
logger.info('Number of test texts in [{:s}]: {:d}'.format(dataset_opt['name'], len(test_set)))
|
||||
|
||||
model = ExtensibleTrainer(opt)
|
||||
|
||||
batch = 0
|
||||
output = open('results.tsv', 'w')
|
||||
for test_loader in test_loaders:
|
||||
dataset_dir = opt['path']['results_root']
|
||||
util.mkdir(dataset_dir)
|
||||
dataset_dir = opt['path']['results_root']
|
||||
util.mkdir(dataset_dir)
|
||||
|
||||
tq = tqdm(test_loader)
|
||||
for data in tq:
|
||||
#if data['clip'].shape[-1] > opt['networks']['asr_gen']['kwargs']['max_mel_frames']*255:
|
||||
# continue
|
||||
preds = forward_pass(model, data, dataset_dir, opt, batch)
|
||||
for b, pred in enumerate(preds):
|
||||
pred = pred.replace('_', '')
|
||||
output.write(f'{pred}\t{os.path.basename(data["filenames"][b])}\n')
|
||||
print(pred)
|
||||
batch += 1
|
||||
output.flush()
|
||||
for data in tqdm(test_loader):
|
||||
#if data['clip'].shape[-1] > opt['networks']['asr_gen']['kwargs']['max_mel_frames']*255:
|
||||
# continue
|
||||
preds = forward_pass(model, data, dataset_dir, opt, batch, test_set)
|
||||
for b, pred in enumerate(preds):
|
||||
pred = pred.replace('_', '')
|
||||
output.write(f'{pred}\t{os.path.basename(data["filenames"][b])}\n')
|
||||
print(pred)
|
||||
batch += 1
|
||||
output.flush()
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user