forked from mrq/DL-Art-School
Add & refine WER evaluator for w2v
This commit is contained in:
parent
e16af944c0
commit
a4f1641eea
|
@ -11,6 +11,11 @@ from trainer.networks import register_model
|
|||
from utils.util import opt_get
|
||||
|
||||
|
||||
def only_letters(string):
|
||||
allowlist = set(' ABCDEFGHIJKLMNOPQRSTUVWXYZ\'')
|
||||
return ''.join(filter(allowlist.__contains__, string.upper()))
|
||||
|
||||
|
||||
class Wav2VecWrapper(nn.Module):
|
||||
"""
|
||||
Basic wrapper class that makes Wav2Vec2 usable by DLAS.
|
||||
|
@ -77,8 +82,8 @@ class Wav2VecWrapper(nn.Module):
|
|||
pred_strings = []
|
||||
for last_labels, last_pred in zip(self.last_labels, self.last_pred):
|
||||
last_labels[last_labels == -100] = 0
|
||||
label_strings.extend([sequence_to_text(lbl) for lbl in last_labels])
|
||||
pred_strings.extend([sequence_to_text(self.decode_ctc(pred)) for pred in last_pred])
|
||||
label_strings.extend([only_letters(sequence_to_text(lbl)) for lbl in last_labels])
|
||||
pred_strings.extend([only_letters(sequence_to_text(self.decode_ctc(pred))) for pred in last_pred])
|
||||
wer = wer_metric.compute(predictions=pred_strings, references=label_strings)
|
||||
res['wer'] = wer
|
||||
print(f"Sample prediction: {pred_strings[0]} <=> {label_strings[0]}")
|
||||
|
@ -88,7 +93,7 @@ class Wav2VecWrapper(nn.Module):
|
|||
audio_norm = (audio - audio.mean()) / torch.sqrt(audio.var() + 1e-7)
|
||||
logits = self.w2v(input_values=audio_norm.squeeze(1)).logits
|
||||
pred = logits.argmax(dim=-1)
|
||||
return self.decode_ctc(pred)
|
||||
return [self.decode_ctc(p) for p in pred]
|
||||
|
||||
|
||||
@register_model
|
||||
|
@ -97,6 +102,7 @@ def register_wav2vec2_finetune(opt_net, opt):
|
|||
|
||||
|
||||
if __name__ == '__main__':
|
||||
print(only_letters("Hello, world!"))
|
||||
w2v = Wav2VecWrapper(basis_model='facebook/wav2vec2-large-960h', freeze_transformer=True)
|
||||
loss = w2v(torch.randn(2,1,50000), torch.randint(0,40,(2,70)), torch.tensor([20000, 30000]), torch.tensor([35, 50]))
|
||||
w2v.get_debug_values(0,"")
|
||||
|
|
48
codes/trainer/eval/eval_wer.py
Normal file
48
codes/trainer/eval/eval_wer.py
Normal file
|
@ -0,0 +1,48 @@
|
|||
from copy import deepcopy
|
||||
|
||||
from datasets import load_metric
|
||||
|
||||
import torch
|
||||
import trainer.eval.evaluator as evaluator
|
||||
from data import create_dataset, create_dataloader
|
||||
from models.asr.w2v_wrapper import only_letters
|
||||
from models.tacotron2.text import sequence_to_text
|
||||
|
||||
|
||||
class WerEvaluator(evaluator.Evaluator):
|
||||
"""
|
||||
Evaluator that produces the WER for a speech recognition model on a test set.
|
||||
"""
|
||||
def __init__(self, model, opt_eval, env):
|
||||
super().__init__(model, opt_eval, env, uses_all_ddp=False)
|
||||
self.clip_key = opt_eval['clip_key']
|
||||
self.clip_lengths_key = opt_eval['clip_lengths_key']
|
||||
self.text_seq_key = opt_eval['text_seq_key']
|
||||
self.text_seq_lengths_key = opt_eval['text_seq_lengths_key']
|
||||
self.wer_metric = load_metric('wer')
|
||||
|
||||
def perform_eval(self):
|
||||
val_opt = deepcopy(self.env['opt']['datasets']['val'])
|
||||
val_opt['batch_size'] = 1 # This is important to ensure no padding.
|
||||
val_dataset, collate_fn = create_dataset(val_opt, return_collate=True)
|
||||
val_loader = create_dataloader(val_dataset, val_opt, self.env['opt'], None, collate_fn=collate_fn)
|
||||
model = self.model.module if hasattr(self.model, 'module') else self.model # Unwrap DDP models
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
preds = []
|
||||
reals = []
|
||||
for batch in val_loader:
|
||||
clip = batch[self.clip_key]
|
||||
assert clip.shape[0] == 1
|
||||
clip_len = batch[self.clip_lengths_key][0]
|
||||
clip = clip[:, :, :clip_len].cuda()
|
||||
pred_seq = model.inference(clip)
|
||||
preds.append(only_letters(sequence_to_text(pred_seq[0])))
|
||||
real_seq = batch[self.text_seq_key]
|
||||
real_seq_len = batch[self.text_seq_lengths_key][0]
|
||||
real_seq = real_seq[:, :real_seq_len]
|
||||
reals.append(only_letters(sequence_to_text(real_seq[0])))
|
||||
wer = self.wer_metric.compute(predictions=preds, references=reals)
|
||||
model.train()
|
||||
return {'eval_wer': wer}
|
||||
|
Loading…
Reference in New Issue
Block a user