Update to evaluator

This commit is contained in:
James Betker 2022-02-17 17:30:33 -07:00
parent e1d71e1bd5
commit a813fbed9c

View File

@ -3,24 +3,47 @@ from copy import deepcopy
from datasets import load_metric from datasets import load_metric
import torch import torch
from tqdm import tqdm
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
import trainer.eval.evaluator as evaluator import trainer.eval.evaluator as evaluator
from data import create_dataset, create_dataloader from data import create_dataset, create_dataloader
from models.asr.w2v_wrapper import only_letters from models.asr.w2v_wrapper import only_letters, Wav2VecWrapper
from models.tacotron2.text import sequence_to_text from models.tacotron2.text import sequence_to_text
# Fine-tuned target for w2v-large: 4.487% WER. # Librispeech:
# baseline: .045% WER.
# fine-tuned new head (0): .054% WER
#
# baseline: .328
# 0: .342
# 24000: .346
def tacotron_detokenize(seq):
return only_letters(sequence_to_text(seq))
fb_processor = None
def fb_detokenize(seq):
global fb_processor
if fb_processor is None:
fb_processor = Wav2Vec2Processor.from_pretrained(f"facebook/wav2vec2-large-960h")
return fb_processor.decode(seq)
class WerEvaluator(evaluator.Evaluator): class WerEvaluator(evaluator.Evaluator):
""" """
Evaluator that produces the WER for a speech recognition model on a test set. Evaluator that produces the WER for a speech recognition model on a test set.
""" """
def __init__(self, model, opt_eval, env): def __init__(self, model, opt_eval, env, detokenizer_fn=tacotron_detokenize):
super().__init__(model, opt_eval, env, uses_all_ddp=False) super().__init__(model, opt_eval, env, uses_all_ddp=False)
self.clip_key = opt_eval['clip_key'] self.clip_key = opt_eval['clip_key']
self.clip_lengths_key = opt_eval['clip_lengths_key'] self.clip_lengths_key = opt_eval['clip_lengths_key']
self.text_seq_key = opt_eval['text_seq_key'] self.text_seq_key = opt_eval['text_seq_key']
self.text_seq_lengths_key = opt_eval['text_seq_lengths_key'] self.text_seq_lengths_key = opt_eval['text_seq_lengths_key']
self.wer_metric = load_metric('wer') self.wer_metric = load_metric('wer')
self.detokenizer_fn = detokenizer_fn
def perform_eval(self): def perform_eval(self):
val_opt = deepcopy(self.env['opt']['datasets']['val']) val_opt = deepcopy(self.env['opt']['datasets']['val'])
@ -32,18 +55,55 @@ class WerEvaluator(evaluator.Evaluator):
with torch.no_grad(): with torch.no_grad():
preds = [] preds = []
reals = [] reals = []
for batch in val_loader: for batch in tqdm(val_loader):
clip = batch[self.clip_key] clip = batch[self.clip_key]
assert clip.shape[0] == 1 assert clip.shape[0] == 1
clip_len = batch[self.clip_lengths_key][0]
clip = clip[:, :, :clip_len].cuda()
pred_seq = model.inference(clip)
preds.append(only_letters(sequence_to_text(pred_seq[0])))
real_seq = batch[self.text_seq_key] real_seq = batch[self.text_seq_key]
real_seq_len = batch[self.text_seq_lengths_key][0] real_seq_len = batch[self.text_seq_lengths_key][0]
real_seq = real_seq[:, :real_seq_len] real_seq = real_seq[:, :real_seq_len]
reals.append(only_letters(sequence_to_text(real_seq[0]))) real_str = only_letters(sequence_to_text(real_seq[0]))
if len(real_str) > 0:
reals.append(real_str)
else:
continue # The WER computer doesn't like this scenario.
clip_len = batch[self.clip_lengths_key][0]
clip = clip[:, :, :clip_len].cuda()
pred_seq = model.inference(clip)
preds.append(self.detokenizer_fn(pred_seq[0]))
wer = self.wer_metric.compute(predictions=preds, references=reals) wer = self.wer_metric.compute(predictions=preds, references=reals)
model.train() model.train()
return {'eval_wer': wer} return {'eval_wer': wer}
if __name__ == '__main__':
env = { 'opt': {
'datasets': {
'val': {
'name': 'mass_test',
'n_workers': 1,
'batch_size': 1,
'mode': 'paired_voice_audio',
'sample_rate': 16000,
'path': ['y:/bigasr_dataset/mozcv/en/test.tsv'],
'fetcher_mode': ['mozilla_cv'],
'max_wav_length': 200000,
'use_bpe_tokenizer': False,
'max_text_length': 400,
'load_conditioning': False,
'phase': 'eval',
}
}
}}
opt_eval = {
'clip_key': 'wav',
'clip_lengths_key': 'wav_lengths',
'text_seq_key': 'padded_text',
'text_seq_lengths_key': 'text_lengths',
}
model = Wav2VecWrapper(vocab_size=148, basis_model='facebook/wav2vec2-large-960h', freeze_transformer=True, checkpointing_enabled=False)
model.w2v = Wav2Vec2ForCTC.from_pretrained('facebook/wav2vec2-large-960h')
weights = torch.load('X:\\dlas\\experiments\\train_wav2vec_mass_large\\models\\0_wav2vec.pth')
#model.load_state_dict(weights)
model = model.cuda()
eval = WerEvaluator(model, opt_eval, env, detokenizer_fn=fb_detokenize)
print(eval.perform_eval())