diff --git a/codes/trainer/eval/eval_wer.py b/codes/trainer/eval/eval_wer.py
index 1d89b399..7b91a78b 100644
--- a/codes/trainer/eval/eval_wer.py
+++ b/codes/trainer/eval/eval_wer.py
@@ -9,15 +9,15 @@ from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
 import trainer.eval.evaluator as evaluator
 from data import create_dataset, create_dataloader
 from models.asr.w2v_wrapper import only_letters, Wav2VecWrapper
-from models.tacotron2.text import sequence_to_text
+from models.tacotron2.text import sequence_to_text, tacotron_symbols
+from pyctcdecode import build_ctcdecoder
 
 # Librispeech:
-# baseline: .045% WER.
-# fine-tuned new head (0): .054% WER
-#
-# baseline: .328
-# 0: .342
-# 24000: .346
+# baseline: 4.5% WER.
+# fine-tuned new head (0): 5.4% WER
+# train_wav2vec_mass_large/models/13250_wav2vec.pth: 3.05% WER
+# train_wav2vec_mass_large/models/13250_wav2vec.pth with kenlm: 3.34% WER
+from utils.util import opt_get
 
 
 def tacotron_detokenize(seq):
@@ -32,6 +32,23 @@ def fb_detokenize(seq):
     return fb_processor.decode(seq)
 
 
+def perform_lm_processing(logits, decoder):
+    from pyctcdecode.constants import (
+        DEFAULT_BEAM_WIDTH,
+        DEFAULT_MIN_TOKEN_LOGP,
+        DEFAULT_PRUNE_LOGP,
+    )
+
+    assert len(logits.shape) == 3 and logits.shape[0] == 1
+    decoded_beams = decoder.decode_beams(
+        logits[0].cpu().numpy(),
+        beam_width=DEFAULT_BEAM_WIDTH,
+        beam_prune_logp=DEFAULT_PRUNE_LOGP,
+        token_min_logp=DEFAULT_MIN_TOKEN_LOGP
+    )
+    text = decoded_beams[0][0]
+    return only_letters(text.upper())
+
 class WerEvaluator(evaluator.Evaluator):
     """
     Evaluator that produces the WER for a speech recognition model on a test set.
@@ -45,6 +62,10 @@ class WerEvaluator(evaluator.Evaluator):
         self.wer_metric = load_metric('wer')
         self.detokenizer_fn = detokenizer_fn
 
+        self.kenlm_model_path = opt_get(opt_eval, ['kenlm_path'], None)
+        if self.kenlm_model_path is not None:
+            self.kenlm_decoder = build_ctcdecoder(labels=tacotron_symbols(), kenlm_model_path=self.kenlm_model_path)
+
     def perform_eval(self):
         val_opt = deepcopy(self.env['opt']['datasets']['val'])
         val_opt['batch_size'] = 1 # This is important to ensure no padding.
@@ -68,8 +89,14 @@ class WerEvaluator(evaluator.Evaluator):
                 continue # The WER computer doesn't like this scenario.
             clip_len = batch[self.clip_lengths_key][0]
             clip = clip[:, :, :clip_len].cuda()
-            pred_seq = model.inference(clip)
-            preds.append(self.detokenizer_fn(pred_seq[0]))
+            logits = model.inference_logits(clip)
+            if self.kenlm_model_path is not None:
+                pred = perform_lm_processing(logits, self.kenlm_decoder)
+            else:
+                pred_seq = logits.argmax(dim=-1)
+                pred_seq = [model.decode_ctc(p) for p in pred_seq]
+                pred = self.detokenizer_fn(pred_seq[0])
+            preds.append(pred)
         wer = self.wer_metric.compute(predictions=preds, references=reals)
         model.train()
         return {'eval_wer': wer}
@@ -84,8 +111,10 @@ if __name__ == '__main__':
         'batch_size': 1,
         'mode': 'paired_voice_audio',
         'sample_rate': 16000,
-        'path': ['y:/bigasr_dataset/mozcv/en/test.tsv'],
-        'fetcher_mode': ['mozilla_cv'],
+        'path': ['y:/bigasr_dataset/librispeech/test_clean/test_clean.txt'],
+        'fetcher_mode': ['libritts'],
+        #'path': ['y:/bigasr_dataset/mozcv/en/test.tsv'],
+        #'fetcher_mode': ['mozilla_cv'],
         'max_wav_length': 200000,
         'use_bpe_tokenizer': False,
         'max_text_length': 400,
@@ -99,9 +128,10 @@ if __name__ == '__main__':
         'clip_lengths_key': 'wav_lengths',
         'text_seq_key': 'padded_text',
         'text_seq_lengths_key': 'text_lengths',
+        'kenlm_path': 'Y:\\bookscorpus-5gram\\5gram.bin'
     }
     model = Wav2VecWrapper(vocab_size=148, basis_model='facebook/wav2vec2-large-robust-ft-libri-960h', freeze_transformer=True, checkpointing_enabled=False)
-    weights = torch.load('X:\\dlas\\experiments/train_wav2vec_mass_diverse_initial_annealing_large_pt/models/7000_wav2vec.pth')
+    weights = torch.load('X:\\dlas\\experiments/train_wav2vec_mass_large/models/13250_wav2vec.pth')
    model.load_state_dict(weights)
    model = model.cuda()
    eval = WerEvaluator(model, opt_eval, env)
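
Usage note: below is a minimal standalone sketch of the pyctcdecode + KenLM decode path this diff wires into the evaluator. The checkpoint and LM paths are placeholders (not the repository's actual files), the input clip is a random stand-in tensor, and the decoder is called with pyctcdecode's default beam-search parameters rather than the explicit constants used in perform_lm_processing.

import torch
from pyctcdecode import build_ctcdecoder

from models.asr.w2v_wrapper import only_letters, Wav2VecWrapper
from models.tacotron2.text import tacotron_symbols

lm_path = 'path/to/5gram.bin'        # placeholder: any KenLM .arpa/.bin model
ckpt_path = 'path/to/wav2vec.pth'    # placeholder: a trained Wav2VecWrapper checkpoint

model = Wav2VecWrapper(vocab_size=148, basis_model='facebook/wav2vec2-large-robust-ft-libri-960h',
                       freeze_transformer=True, checkpointing_enabled=False)
model.load_state_dict(torch.load(ckpt_path))
model = model.cuda().eval()

# The decoder labels must line up with the model's output vocabulary; the evaluator uses tacotron_symbols().
decoder = build_ctcdecoder(labels=tacotron_symbols(), kenlm_model_path=lm_path)

clip = torch.randn(1, 1, 160000).cuda()       # stand-in for a real (1, 1, samples) 16kHz clip
with torch.no_grad():
    logits = model.inference_logits(clip)     # (1, time, vocab), as consumed by perform_lm_processing
beams = decoder.decode_beams(logits[0].cpu().numpy())  # default beam width and pruning thresholds
print(only_letters(beams[0][0].upper()))      # beams[0][0] is the best hypothesis text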