eval: integrate a n-gram language model into decoding

This commit is contained in:
James Betker 2022-02-21 19:12:34 -07:00
parent af50afe222
commit 6313a94f96

View File

@ -9,15 +9,15 @@ from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
import trainer.eval.evaluator as evaluator import trainer.eval.evaluator as evaluator
from data import create_dataset, create_dataloader from data import create_dataset, create_dataloader
from models.asr.w2v_wrapper import only_letters, Wav2VecWrapper from models.asr.w2v_wrapper import only_letters, Wav2VecWrapper
from models.tacotron2.text import sequence_to_text from models.tacotron2.text import sequence_to_text, tacotron_symbols
from pyctcdecode import build_ctcdecoder
# Librispeech: # Librispeech:
# baseline: .045% WER. # baseline: 4.5% WER.
# fine-tuned new head (0): .054% WER # fine-tuned new head (0): 5.4% WER
# # train_wav2vec_mass_large/models/13250_wav2vec.pth: 3.05% WER
# baseline: .328 # train_wav2vec_mass_large/models/13250_wav2vec.pth with kenlm: 3.34% WER
# 0: .342 from utils.util import opt_get
# 24000: .346
def tacotron_detokenize(seq): def tacotron_detokenize(seq):
@ -32,6 +32,23 @@ def fb_detokenize(seq):
return fb_processor.decode(seq) return fb_processor.decode(seq)
def perform_lm_processing(logits, decoder):
from pyctcdecode.constants import (
DEFAULT_BEAM_WIDTH,
DEFAULT_MIN_TOKEN_LOGP,
DEFAULT_PRUNE_LOGP,
)
assert len(logits.shape) == 3 and logits.shape[0] == 1
decoded_beams = decoder.decode_beams(
logits[0].cpu().numpy(),
beam_width=DEFAULT_BEAM_WIDTH,
beam_prune_logp=DEFAULT_PRUNE_LOGP,
token_min_logp=DEFAULT_MIN_TOKEN_LOGP
)
text = decoded_beams[0][0]
return only_letters(text.upper())
class WerEvaluator(evaluator.Evaluator): class WerEvaluator(evaluator.Evaluator):
""" """
Evaluator that produces the WER for a speech recognition model on a test set. Evaluator that produces the WER for a speech recognition model on a test set.
@ -45,6 +62,10 @@ class WerEvaluator(evaluator.Evaluator):
self.wer_metric = load_metric('wer') self.wer_metric = load_metric('wer')
self.detokenizer_fn = detokenizer_fn self.detokenizer_fn = detokenizer_fn
self.kenlm_model_path = opt_get(opt_eval, ['kenlm_path'], None)
if self.kenlm_model_path is not None:
self.kenlm_decoder = build_ctcdecoder(labels=tacotron_symbols(), kenlm_model_path=self.kenlm_model_path)
def perform_eval(self): def perform_eval(self):
val_opt = deepcopy(self.env['opt']['datasets']['val']) val_opt = deepcopy(self.env['opt']['datasets']['val'])
val_opt['batch_size'] = 1 # This is important to ensure no padding. val_opt['batch_size'] = 1 # This is important to ensure no padding.
@ -68,8 +89,14 @@ class WerEvaluator(evaluator.Evaluator):
continue # The WER computer doesn't like this scenario. continue # The WER computer doesn't like this scenario.
clip_len = batch[self.clip_lengths_key][0] clip_len = batch[self.clip_lengths_key][0]
clip = clip[:, :, :clip_len].cuda() clip = clip[:, :, :clip_len].cuda()
pred_seq = model.inference(clip) logits = model.inference_logits(clip)
preds.append(self.detokenizer_fn(pred_seq[0])) if self.kenlm_model_path is not None:
pred = perform_lm_processing(logits, self.kenlm_decoder)
else:
pred_seq = logits.argmax(dim=-1)
pred_seq = [model.decode_ctc(p) for p in pred_seq]
pred = self.detokenizer_fn(pred_seq[0])
preds.append(pred)
wer = self.wer_metric.compute(predictions=preds, references=reals) wer = self.wer_metric.compute(predictions=preds, references=reals)
model.train() model.train()
return {'eval_wer': wer} return {'eval_wer': wer}
@ -84,8 +111,10 @@ if __name__ == '__main__':
'batch_size': 1, 'batch_size': 1,
'mode': 'paired_voice_audio', 'mode': 'paired_voice_audio',
'sample_rate': 16000, 'sample_rate': 16000,
'path': ['y:/bigasr_dataset/mozcv/en/test.tsv'], 'path': ['y:/bigasr_dataset/librispeech/test_clean/test_clean.txt'],
'fetcher_mode': ['mozilla_cv'], 'fetcher_mode': ['libritts'],
#'path': ['y:/bigasr_dataset/mozcv/en/test.tsv'],
#'fetcher_mode': ['mozilla_cv'],
'max_wav_length': 200000, 'max_wav_length': 200000,
'use_bpe_tokenizer': False, 'use_bpe_tokenizer': False,
'max_text_length': 400, 'max_text_length': 400,
@ -99,9 +128,10 @@ if __name__ == '__main__':
'clip_lengths_key': 'wav_lengths', 'clip_lengths_key': 'wav_lengths',
'text_seq_key': 'padded_text', 'text_seq_key': 'padded_text',
'text_seq_lengths_key': 'text_lengths', 'text_seq_lengths_key': 'text_lengths',
'kenlm_path': 'Y:\\bookscorpus-5gram\\5gram.bin'
} }
model = Wav2VecWrapper(vocab_size=148, basis_model='facebook/wav2vec2-large-robust-ft-libri-960h', freeze_transformer=True, checkpointing_enabled=False) model = Wav2VecWrapper(vocab_size=148, basis_model='facebook/wav2vec2-large-robust-ft-libri-960h', freeze_transformer=True, checkpointing_enabled=False)
weights = torch.load('X:\\dlas\\experiments/train_wav2vec_mass_diverse_initial_annealing_large_pt/models/7000_wav2vec.pth') weights = torch.load('X:\\dlas\\experiments/train_wav2vec_mass_large/models/13250_wav2vec.pth')
model.load_state_dict(weights) model.load_state_dict(weights)
model = model.cuda() model = model.cuda()
eval = WerEvaluator(model, opt_eval, env) eval = WerEvaluator(model, opt_eval, env)