Add & refine WER evaluator for w2v

This commit is contained in:
James Betker 2022-02-13 20:47:29 -07:00
parent e16af944c0
commit a4f1641eea
2 changed files with 57 additions and 3 deletions

View File

@ -11,6 +11,11 @@ from trainer.networks import register_model
from utils.util import opt_get
def only_letters(string):
allowlist = set(' ABCDEFGHIJKLMNOPQRSTUVWXYZ\'')
return ''.join(filter(allowlist.__contains__, string.upper()))
class Wav2VecWrapper(nn.Module):
"""
Basic wrapper class that makes Wav2Vec2 usable by DLAS.
@ -77,8 +82,8 @@ class Wav2VecWrapper(nn.Module):
pred_strings = []
for last_labels, last_pred in zip(self.last_labels, self.last_pred):
last_labels[last_labels == -100] = 0
label_strings.extend([sequence_to_text(lbl) for lbl in last_labels])
pred_strings.extend([sequence_to_text(self.decode_ctc(pred)) for pred in last_pred])
label_strings.extend([only_letters(sequence_to_text(lbl)) for lbl in last_labels])
pred_strings.extend([only_letters(sequence_to_text(self.decode_ctc(pred))) for pred in last_pred])
wer = wer_metric.compute(predictions=pred_strings, references=label_strings)
res['wer'] = wer
print(f"Sample prediction: {pred_strings[0]} <=> {label_strings[0]}")
@ -88,7 +93,7 @@ class Wav2VecWrapper(nn.Module):
audio_norm = (audio - audio.mean()) / torch.sqrt(audio.var() + 1e-7)
logits = self.w2v(input_values=audio_norm.squeeze(1)).logits
pred = logits.argmax(dim=-1)
return self.decode_ctc(pred)
return [self.decode_ctc(p) for p in pred]
@register_model
@ -97,6 +102,7 @@ def register_wav2vec2_finetune(opt_net, opt):
if __name__ == '__main__':
print(only_letters("Hello, world!"))
w2v = Wav2VecWrapper(basis_model='facebook/wav2vec2-large-960h', freeze_transformer=True)
loss = w2v(torch.randn(2,1,50000), torch.randint(0,40,(2,70)), torch.tensor([20000, 30000]), torch.tensor([35, 50]))
w2v.get_debug_values(0,"")

View File

@ -0,0 +1,48 @@
from copy import deepcopy
from datasets import load_metric
import torch
import trainer.eval.evaluator as evaluator
from data import create_dataset, create_dataloader
from models.asr.w2v_wrapper import only_letters
from models.tacotron2.text import sequence_to_text
class WerEvaluator(evaluator.Evaluator):
"""
Evaluator that produces the WER for a speech recognition model on a test set.
"""
def __init__(self, model, opt_eval, env):
super().__init__(model, opt_eval, env, uses_all_ddp=False)
self.clip_key = opt_eval['clip_key']
self.clip_lengths_key = opt_eval['clip_lengths_key']
self.text_seq_key = opt_eval['text_seq_key']
self.text_seq_lengths_key = opt_eval['text_seq_lengths_key']
self.wer_metric = load_metric('wer')
def perform_eval(self):
val_opt = deepcopy(self.env['opt']['datasets']['val'])
val_opt['batch_size'] = 1 # This is important to ensure no padding.
val_dataset, collate_fn = create_dataset(val_opt, return_collate=True)
val_loader = create_dataloader(val_dataset, val_opt, self.env['opt'], None, collate_fn=collate_fn)
model = self.model.module if hasattr(self.model, 'module') else self.model # Unwrap DDP models
model.eval()
with torch.no_grad():
preds = []
reals = []
for batch in val_loader:
clip = batch[self.clip_key]
assert clip.shape[0] == 1
clip_len = batch[self.clip_lengths_key][0]
clip = clip[:, :, :clip_len].cuda()
pred_seq = model.inference(clip)
preds.append(only_letters(sequence_to_text(pred_seq[0])))
real_seq = batch[self.text_seq_key]
real_seq_len = batch[self.text_seq_lengths_key][0]
real_seq = real_seq[:, :real_seq_len]
reals.append(only_letters(sequence_to_text(real_seq[0])))
wer = self.wer_metric.compute(predictions=preds, references=reals)
model.train()
return {'eval_wer': wer}