From a5b4bee719530d8fd4c865fbfabae02003b3a644 Mon Sep 17 00:00:00 2001 From: James Betker Date: Tue, 28 Dec 2021 11:45:15 -0700 Subject: [PATCH] Improve asr_eval --- codes/models/gpt_voice/gpt_asr_hf2.py | 12 ++---- codes/scripts/audio/asr_eval.py | 54 ++++++++++++++------------- 2 files changed, 31 insertions(+), 35 deletions(-) diff --git a/codes/models/gpt_voice/gpt_asr_hf2.py b/codes/models/gpt_voice/gpt_asr_hf2.py index 46c1d83d..246cde02 100644 --- a/codes/models/gpt_voice/gpt_asr_hf2.py +++ b/codes/models/gpt_voice/gpt_asr_hf2.py @@ -297,17 +297,11 @@ class GptAsrHf2(nn.Module): self.inference_model.store_mel_emb(mel_emb) # "fake_inputs" are stand-ins for the MEL frames, which will be injected with the prep_inputs function above. - if cond_text is None: - fake_inputs = torch.full((mel_emb.shape[0],mel_emb.shape[1]+1,), fill_value=1, dtype=torch.long, device=mel_inputs.device) - fake_inputs[:,-1] = self.start_token - else: - cond_used = 10 - fake_inputs = torch.full((mel_emb.shape[0],mel_emb.shape[1]+1+cond_used,), fill_value=1, dtype=torch.long, device=mel_inputs.device) - fake_inputs[:,-1-cond_used] = self.start_token - fake_inputs[:, -cond_used:] = cond_text[:, :cond_used] + fake_inputs = torch.full((mel_emb.shape[0],mel_emb.shape[1]+1,), fill_value=1, dtype=torch.long, device=mel_inputs.device) + fake_inputs[:,-1] = self.start_token gen = self.inference_model.generate(fake_inputs, do_sample=do_sample, bos_token_id=self.start_token, pad_token_id=0, eos_token_id=0, max_length=self.max_symbols_per_phrase+mel_emb.shape[1], temperature=temperature, num_beams=num_beams, use_cache=True) - return gen[:, mel_emb.shape[1]:] + return gen[:, mel_emb.shape[1]+1:] @register_model diff --git a/codes/scripts/audio/asr_eval.py b/codes/scripts/audio/asr_eval.py index 7aa9976c..a16045a7 100644 --- a/codes/scripts/audio/asr_eval.py +++ b/codes/scripts/audio/asr_eval.py @@ -18,19 +18,25 @@ import numpy as np from scipy.io import wavfile -def forward_pass(model, data, output_dir, opt, b): +def forward_pass(model, data, output_dir, opt, macro_b, dataset): with torch.no_grad(): model.feed_data(data, 0) model.test() - if 'real_text' in opt['eval'].keys(): - real = data[opt['eval']['real_text']][0] - print(f'{b} Real text: "{real}"') - gt_key = opt['eval']['gen_text'] txts = [] for b in range(model.eval_state[gt_key][0].shape[0]): - txts.append(sequence_to_text(model.eval_state[opt['eval']['gen_text']][0][b])) + if 'real_text' in opt['eval'].keys(): + real = data[opt['eval']['real_text']][b] + print(f'{macro_b} {b} Real text: "{real}"') + + codes = model.eval_state[opt['eval']['gen_text']][0][b].cpu() + if hasattr(dataset, 'tokenizer'): + text = dataset.tokenizer.decode(codes.numpy()) + text = text.replace(' $$$', '') + txts.append(text) + else: + txts.append(sequence_to_text(codes)) return txts @@ -57,30 +63,26 @@ if __name__ == "__main__": logger = logging.getLogger('base') logger.info(option.dict2str(opt)) - test_loaders = [] - for phase, dataset_opt in sorted(opt['datasets'].items()): - test_set, collate_fn = create_dataset(dataset_opt, return_collate=True) - test_loader = create_dataloader(test_set, dataset_opt, collate_fn=collate_fn) - logger.info('Number of test texts in [{:s}]: {:d}'.format(dataset_opt['name'], len(test_set))) - test_loaders.append(test_loader) + dataset_opt = opt['datasets']['val'] + test_set, collate_fn = create_dataset(dataset_opt, return_collate=True) + test_loader = create_dataloader(test_set, dataset_opt, collate_fn=collate_fn) + logger.info('Number of test texts in [{:s}]: {:d}'.format(dataset_opt['name'], len(test_set))) model = ExtensibleTrainer(opt) batch = 0 output = open('results.tsv', 'w') - for test_loader in test_loaders: - dataset_dir = opt['path']['results_root'] - util.mkdir(dataset_dir) + dataset_dir = opt['path']['results_root'] + util.mkdir(dataset_dir) - tq = tqdm(test_loader) - for data in tq: - #if data['clip'].shape[-1] > opt['networks']['asr_gen']['kwargs']['max_mel_frames']*255: - # continue - preds = forward_pass(model, data, dataset_dir, opt, batch) - for b, pred in enumerate(preds): - pred = pred.replace('_', '') - output.write(f'{pred}\t{os.path.basename(data["filenames"][b])}\n') - print(pred) - batch += 1 - output.flush() + for data in tqdm(test_loader): + #if data['clip'].shape[-1] > opt['networks']['asr_gen']['kwargs']['max_mel_frames']*255: + # continue + preds = forward_pass(model, data, dataset_dir, opt, batch, test_set) + for b, pred in enumerate(preds): + pred = pred.replace('_', '') + output.write(f'{pred}\t{os.path.basename(data["filenames"][b])}\n') + print(pred) + batch += 1 + output.flush()