diff --git a/codes/models/asr/w2v_wrapper.py b/codes/models/asr/w2v_wrapper.py index 153a8385..0348959e 100644 --- a/codes/models/asr/w2v_wrapper.py +++ b/codes/models/asr/w2v_wrapper.py @@ -11,6 +11,11 @@ from trainer.networks import register_model from utils.util import opt_get +def only_letters(string): + allowlist = set(' ABCDEFGHIJKLMNOPQRSTUVWXYZ\'') + return ''.join(filter(allowlist.__contains__, string.upper())) + + class Wav2VecWrapper(nn.Module): """ Basic wrapper class that makes Wav2Vec2 usable by DLAS. @@ -77,8 +82,8 @@ class Wav2VecWrapper(nn.Module): pred_strings = [] for last_labels, last_pred in zip(self.last_labels, self.last_pred): last_labels[last_labels == -100] = 0 - label_strings.extend([sequence_to_text(lbl) for lbl in last_labels]) - pred_strings.extend([sequence_to_text(self.decode_ctc(pred)) for pred in last_pred]) + label_strings.extend([only_letters(sequence_to_text(lbl)) for lbl in last_labels]) + pred_strings.extend([only_letters(sequence_to_text(self.decode_ctc(pred))) for pred in last_pred]) wer = wer_metric.compute(predictions=pred_strings, references=label_strings) res['wer'] = wer print(f"Sample prediction: {pred_strings[0]} <=> {label_strings[0]}") @@ -88,7 +93,7 @@ class Wav2VecWrapper(nn.Module): audio_norm = (audio - audio.mean()) / torch.sqrt(audio.var() + 1e-7) logits = self.w2v(input_values=audio_norm.squeeze(1)).logits pred = logits.argmax(dim=-1) - return self.decode_ctc(pred) + return [self.decode_ctc(p) for p in pred] @register_model @@ -97,6 +102,7 @@ def register_wav2vec2_finetune(opt_net, opt): if __name__ == '__main__': + print(only_letters("Hello, world!")) w2v = Wav2VecWrapper(basis_model='facebook/wav2vec2-large-960h', freeze_transformer=True) loss = w2v(torch.randn(2,1,50000), torch.randint(0,40,(2,70)), torch.tensor([20000, 30000]), torch.tensor([35, 50])) w2v.get_debug_values(0,"") diff --git a/codes/trainer/eval/eval_wer.py b/codes/trainer/eval/eval_wer.py new file mode 100644 index 00000000..604965f0 --- /dev/null +++ b/codes/trainer/eval/eval_wer.py @@ -0,0 +1,48 @@ +from copy import deepcopy + +from datasets import load_metric + +import torch +import trainer.eval.evaluator as evaluator +from data import create_dataset, create_dataloader +from models.asr.w2v_wrapper import only_letters +from models.tacotron2.text import sequence_to_text + + +class WerEvaluator(evaluator.Evaluator): + """ + Evaluator that produces the WER for a speech recognition model on a test set. + """ + def __init__(self, model, opt_eval, env): + super().__init__(model, opt_eval, env, uses_all_ddp=False) + self.clip_key = opt_eval['clip_key'] + self.clip_lengths_key = opt_eval['clip_lengths_key'] + self.text_seq_key = opt_eval['text_seq_key'] + self.text_seq_lengths_key = opt_eval['text_seq_lengths_key'] + self.wer_metric = load_metric('wer') + + def perform_eval(self): + val_opt = deepcopy(self.env['opt']['datasets']['val']) + val_opt['batch_size'] = 1 # This is important to ensure no padding. + val_dataset, collate_fn = create_dataset(val_opt, return_collate=True) + val_loader = create_dataloader(val_dataset, val_opt, self.env['opt'], None, collate_fn=collate_fn) + model = self.model.module if hasattr(self.model, 'module') else self.model # Unwrap DDP models + model.eval() + with torch.no_grad(): + preds = [] + reals = [] + for batch in val_loader: + clip = batch[self.clip_key] + assert clip.shape[0] == 1 + clip_len = batch[self.clip_lengths_key][0] + clip = clip[:, :, :clip_len].cuda() + pred_seq = model.inference(clip) + preds.append(only_letters(sequence_to_text(pred_seq[0]))) + real_seq = batch[self.text_seq_key] + real_seq_len = batch[self.text_seq_lengths_key][0] + real_seq = real_seq[:, :real_seq_len] + reals.append(only_letters(sequence_to_text(real_seq[0]))) + wer = self.wer_metric.compute(predictions=preds, references=reals) + model.train() + return {'eval_wer': wer} +