forked from mrq/DL-Art-School
Add & refine WER evaluator for w2v
This commit is contained in:
parent
e16af944c0
commit
a4f1641eea
|
@ -11,6 +11,11 @@ from trainer.networks import register_model
|
||||||
from utils.util import opt_get
|
from utils.util import opt_get
|
||||||
|
|
||||||
|
|
||||||
|
# Characters that survive WER normalization: uppercase letters, space, apostrophe.
# Hoisted to module level so the set is built once, not on every call.
_WER_ALLOWED_CHARS = frozenset(" ABCDEFGHIJKLMNOPQRSTUVWXYZ'")


def only_letters(string):
    """Uppercase *string* and drop every character except A-Z, space and apostrophe.

    Used to normalize both predictions and references before computing WER, so
    that punctuation and casing differences are not counted as word errors.
    """
    return ''.join(filter(_WER_ALLOWED_CHARS.__contains__, string.upper()))
|
||||||
|
|
||||||
|
|
||||||
class Wav2VecWrapper(nn.Module):
|
class Wav2VecWrapper(nn.Module):
|
||||||
"""
|
"""
|
||||||
Basic wrapper class that makes Wav2Vec2 usable by DLAS.
|
Basic wrapper class that makes Wav2Vec2 usable by DLAS.
|
||||||
|
@ -77,8 +82,8 @@ class Wav2VecWrapper(nn.Module):
|
||||||
pred_strings = []
|
pred_strings = []
|
||||||
for last_labels, last_pred in zip(self.last_labels, self.last_pred):
|
for last_labels, last_pred in zip(self.last_labels, self.last_pred):
|
||||||
last_labels[last_labels == -100] = 0
|
last_labels[last_labels == -100] = 0
|
||||||
label_strings.extend([sequence_to_text(lbl) for lbl in last_labels])
|
label_strings.extend([only_letters(sequence_to_text(lbl)) for lbl in last_labels])
|
||||||
pred_strings.extend([sequence_to_text(self.decode_ctc(pred)) for pred in last_pred])
|
pred_strings.extend([only_letters(sequence_to_text(self.decode_ctc(pred))) for pred in last_pred])
|
||||||
wer = wer_metric.compute(predictions=pred_strings, references=label_strings)
|
wer = wer_metric.compute(predictions=pred_strings, references=label_strings)
|
||||||
res['wer'] = wer
|
res['wer'] = wer
|
||||||
print(f"Sample prediction: {pred_strings[0]} <=> {label_strings[0]}")
|
print(f"Sample prediction: {pred_strings[0]} <=> {label_strings[0]}")
|
||||||
|
@ -88,7 +93,7 @@ class Wav2VecWrapper(nn.Module):
|
||||||
audio_norm = (audio - audio.mean()) / torch.sqrt(audio.var() + 1e-7)
|
audio_norm = (audio - audio.mean()) / torch.sqrt(audio.var() + 1e-7)
|
||||||
logits = self.w2v(input_values=audio_norm.squeeze(1)).logits
|
logits = self.w2v(input_values=audio_norm.squeeze(1)).logits
|
||||||
pred = logits.argmax(dim=-1)
|
pred = logits.argmax(dim=-1)
|
||||||
return self.decode_ctc(pred)
|
return [self.decode_ctc(p) for p in pred]
|
||||||
|
|
||||||
|
|
||||||
@register_model
|
@register_model
|
||||||
|
@ -97,6 +102,7 @@ def register_wav2vec2_finetune(opt_net, opt):
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
    # Smoke test: exercise the normalization helper, then run a dummy forward
    # pass through the wrapper and dump its debug metrics.
    print(only_letters("Hello, world!"))
    wrapper = Wav2VecWrapper(basis_model='facebook/wav2vec2-large-960h', freeze_transformer=True)
    dummy_audio = torch.randn(2,1,50000)
    dummy_text = torch.randint(0,40,(2,70))
    sample_loss = wrapper(dummy_audio, dummy_text, torch.tensor([20000, 30000]), torch.tensor([35, 50]))
    wrapper.get_debug_values(0,"")
|
||||||
|
|
48
codes/trainer/eval/eval_wer.py
Normal file
48
codes/trainer/eval/eval_wer.py
Normal file
|
@ -0,0 +1,48 @@
|
||||||
|
from copy import deepcopy
|
||||||
|
|
||||||
|
from datasets import load_metric
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import trainer.eval.evaluator as evaluator
|
||||||
|
from data import create_dataset, create_dataloader
|
||||||
|
from models.asr.w2v_wrapper import only_letters
|
||||||
|
from models.tacotron2.text import sequence_to_text
|
||||||
|
|
||||||
|
|
||||||
|
class WerEvaluator(evaluator.Evaluator):
    """
    Evaluator that produces the WER for a speech recognition model on a test set.

    Required opt_eval keys:
      clip_key / clip_lengths_key: batch keys holding the audio clips and their
        unpadded lengths.
      text_seq_key / text_seq_lengths_key: batch keys holding the reference text
        sequences and their unpadded lengths.
    """
    def __init__(self, model, opt_eval, env):
        super().__init__(model, opt_eval, env, uses_all_ddp=False)
        self.clip_key = opt_eval['clip_key']
        self.clip_lengths_key = opt_eval['clip_lengths_key']
        self.text_seq_key = opt_eval['text_seq_key']
        self.text_seq_lengths_key = opt_eval['text_seq_lengths_key']
        self.wer_metric = load_metric('wer')

    def perform_eval(self):
        """Run the model over the validation set and return {'eval_wer': wer}."""
        val_opt = deepcopy(self.env['opt']['datasets']['val'])
        val_opt['batch_size'] = 1  # This is important to ensure no padding.
        val_dataset, collate_fn = create_dataset(val_opt, return_collate=True)
        val_loader = create_dataloader(val_dataset, val_opt, self.env['opt'], None, collate_fn=collate_fn)
        model = self.model.module if hasattr(self.model, 'module') else self.model  # Unwrap DDP models
        model.eval()
        try:
            with torch.no_grad():
                preds = []
                reals = []
                for batch in val_loader:
                    clip = batch[self.clip_key]
                    assert clip.shape[0] == 1
                    # Trim padding from the clip; with batch_size=1 the stored
                    # length is exact for this single element.
                    clip_len = batch[self.clip_lengths_key][0]
                    clip = clip[:, :, :clip_len].cuda()
                    pred_seq = model.inference(clip)
                    preds.append(only_letters(sequence_to_text(pred_seq[0])))
                    real_seq = batch[self.text_seq_key]
                    real_seq_len = batch[self.text_seq_lengths_key][0]
                    real_seq = real_seq[:, :real_seq_len]
                    reals.append(only_letters(sequence_to_text(real_seq[0])))
                wer = self.wer_metric.compute(predictions=preds, references=reals)
        finally:
            # Restore training mode even if evaluation raises, so a failed eval
            # cannot leave the model stuck in eval mode for subsequent training.
            model.train()
        return {'eval_wer': wer}
|
||||||
|
|
Loading…
Reference in New Issue
Block a user