forked from mrq/DL-Art-School
Improve asr_eval
This commit is contained in:
parent
312f631c5b
commit
a5b4bee719
|
@ -297,17 +297,11 @@ class GptAsrHf2(nn.Module):
|
||||||
self.inference_model.store_mel_emb(mel_emb)
|
self.inference_model.store_mel_emb(mel_emb)
|
||||||
|
|
||||||
# "fake_inputs" are stand-ins for the MEL frames, which will be injected with the prep_inputs function above.
|
# "fake_inputs" are stand-ins for the MEL frames, which will be injected with the prep_inputs function above.
|
||||||
if cond_text is None:
|
fake_inputs = torch.full((mel_emb.shape[0],mel_emb.shape[1]+1,), fill_value=1, dtype=torch.long, device=mel_inputs.device)
|
||||||
fake_inputs = torch.full((mel_emb.shape[0],mel_emb.shape[1]+1,), fill_value=1, dtype=torch.long, device=mel_inputs.device)
|
fake_inputs[:,-1] = self.start_token
|
||||||
fake_inputs[:,-1] = self.start_token
|
|
||||||
else:
|
|
||||||
cond_used = 10
|
|
||||||
fake_inputs = torch.full((mel_emb.shape[0],mel_emb.shape[1]+1+cond_used,), fill_value=1, dtype=torch.long, device=mel_inputs.device)
|
|
||||||
fake_inputs[:,-1-cond_used] = self.start_token
|
|
||||||
fake_inputs[:, -cond_used:] = cond_text[:, :cond_used]
|
|
||||||
gen = self.inference_model.generate(fake_inputs, do_sample=do_sample, bos_token_id=self.start_token, pad_token_id=0, eos_token_id=0,
|
gen = self.inference_model.generate(fake_inputs, do_sample=do_sample, bos_token_id=self.start_token, pad_token_id=0, eos_token_id=0,
|
||||||
max_length=self.max_symbols_per_phrase+mel_emb.shape[1], temperature=temperature, num_beams=num_beams, use_cache=True)
|
max_length=self.max_symbols_per_phrase+mel_emb.shape[1], temperature=temperature, num_beams=num_beams, use_cache=True)
|
||||||
return gen[:, mel_emb.shape[1]:]
|
return gen[:, mel_emb.shape[1]+1:]
|
||||||
|
|
||||||
|
|
||||||
@register_model
|
@register_model
|
||||||
|
|
|
@ -18,19 +18,25 @@ import numpy as np
|
||||||
from scipy.io import wavfile
|
from scipy.io import wavfile
|
||||||
|
|
||||||
|
|
||||||
def forward_pass(model, data, output_dir, opt, b):
|
def forward_pass(model, data, output_dir, opt, macro_b, dataset):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
model.feed_data(data, 0)
|
model.feed_data(data, 0)
|
||||||
model.test()
|
model.test()
|
||||||
|
|
||||||
if 'real_text' in opt['eval'].keys():
|
|
||||||
real = data[opt['eval']['real_text']][0]
|
|
||||||
print(f'{b} Real text: "{real}"')
|
|
||||||
|
|
||||||
gt_key = opt['eval']['gen_text']
|
gt_key = opt['eval']['gen_text']
|
||||||
txts = []
|
txts = []
|
||||||
for b in range(model.eval_state[gt_key][0].shape[0]):
|
for b in range(model.eval_state[gt_key][0].shape[0]):
|
||||||
txts.append(sequence_to_text(model.eval_state[opt['eval']['gen_text']][0][b]))
|
if 'real_text' in opt['eval'].keys():
|
||||||
|
real = data[opt['eval']['real_text']][b]
|
||||||
|
print(f'{macro_b} {b} Real text: "{real}"')
|
||||||
|
|
||||||
|
codes = model.eval_state[opt['eval']['gen_text']][0][b].cpu()
|
||||||
|
if hasattr(dataset, 'tokenizer'):
|
||||||
|
text = dataset.tokenizer.decode(codes.numpy())
|
||||||
|
text = text.replace(' $$$', '')
|
||||||
|
txts.append(text)
|
||||||
|
else:
|
||||||
|
txts.append(sequence_to_text(codes))
|
||||||
return txts
|
return txts
|
||||||
|
|
||||||
|
|
||||||
|
@ -57,30 +63,26 @@ if __name__ == "__main__":
|
||||||
logger = logging.getLogger('base')
|
logger = logging.getLogger('base')
|
||||||
logger.info(option.dict2str(opt))
|
logger.info(option.dict2str(opt))
|
||||||
|
|
||||||
test_loaders = []
|
dataset_opt = opt['datasets']['val']
|
||||||
for phase, dataset_opt in sorted(opt['datasets'].items()):
|
test_set, collate_fn = create_dataset(dataset_opt, return_collate=True)
|
||||||
test_set, collate_fn = create_dataset(dataset_opt, return_collate=True)
|
test_loader = create_dataloader(test_set, dataset_opt, collate_fn=collate_fn)
|
||||||
test_loader = create_dataloader(test_set, dataset_opt, collate_fn=collate_fn)
|
logger.info('Number of test texts in [{:s}]: {:d}'.format(dataset_opt['name'], len(test_set)))
|
||||||
logger.info('Number of test texts in [{:s}]: {:d}'.format(dataset_opt['name'], len(test_set)))
|
|
||||||
test_loaders.append(test_loader)
|
|
||||||
|
|
||||||
model = ExtensibleTrainer(opt)
|
model = ExtensibleTrainer(opt)
|
||||||
|
|
||||||
batch = 0
|
batch = 0
|
||||||
output = open('results.tsv', 'w')
|
output = open('results.tsv', 'w')
|
||||||
for test_loader in test_loaders:
|
dataset_dir = opt['path']['results_root']
|
||||||
dataset_dir = opt['path']['results_root']
|
util.mkdir(dataset_dir)
|
||||||
util.mkdir(dataset_dir)
|
|
||||||
|
|
||||||
tq = tqdm(test_loader)
|
for data in tqdm(test_loader):
|
||||||
for data in tq:
|
#if data['clip'].shape[-1] > opt['networks']['asr_gen']['kwargs']['max_mel_frames']*255:
|
||||||
#if data['clip'].shape[-1] > opt['networks']['asr_gen']['kwargs']['max_mel_frames']*255:
|
# continue
|
||||||
# continue
|
preds = forward_pass(model, data, dataset_dir, opt, batch, test_set)
|
||||||
preds = forward_pass(model, data, dataset_dir, opt, batch)
|
for b, pred in enumerate(preds):
|
||||||
for b, pred in enumerate(preds):
|
pred = pred.replace('_', '')
|
||||||
pred = pred.replace('_', '')
|
output.write(f'{pred}\t{os.path.basename(data["filenames"][b])}\n')
|
||||||
output.write(f'{pred}\t{os.path.basename(data["filenames"][b])}\n')
|
print(pred)
|
||||||
print(pred)
|
batch += 1
|
||||||
batch += 1
|
output.flush()
|
||||||
output.flush()
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user