From 15437b2fc33f7f7cfe7f8f277b4c3f993e75dd78 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Tue, 26 Oct 2021 13:30:29 -0600
Subject: [PATCH] WER script

---
 codes/scripts/audio/word_error_rate.py | 86 +++++++++++++++++++++++++++++++
 1 file changed, 86 insertions(+)
 create mode 100644 codes/scripts/audio/word_error_rate.py

diff --git a/codes/scripts/audio/word_error_rate.py b/codes/scripts/audio/word_error_rate.py
new file mode 100644
index 00000000..bcc257a8
--- /dev/null
+++ b/codes/scripts/audio/word_error_rate.py
@@ -0,0 +1,86 @@
+# Original source: https://github.com/SeanNaren/deepspeech.pytorch/blob/master/deepspeech_pytorch/validation.py
+import os
+
+import Levenshtein as Lev
+import torch
+from tqdm import tqdm
+
+from models.tacotron2.text import cleaners
+
+
+def clean_text(text):
+    for name in ['english_cleaners']:
+        cleaner = getattr(cleaners, name, None)
+        if not cleaner:
+            raise Exception('Unknown cleaner: %s' % name)
+        text = cleaner(text)
+    return text
+
+
+# Converts text to all-uppercase and separates punctuation from words.
+def normalize_text(text):
+    text = text.upper()
+    for punc in ['.', ',', ':', ';']:
+        text = text.replace(punc, f' {punc}')
+    return text.strip()
+
+
+class WordErrorRate:
+    def __init__(self):
+        # Running totals accumulated across calculate_metric() calls.
+        self.wer = 0
+        self.n_tokens = 0
+
+    def calculate_metric(self, transcript, reference):
+        wer_inst = self.wer_calc(transcript, reference)
+        self.wer += wer_inst
+        self.n_tokens += len(reference.split())
+
+    def compute(self):
+        wer = float(self.wer) / self.n_tokens
+        return wer * 100
+
+    def wer_calc(self, s1, s2):
+        """
+        Computes the Word Error Rate, defined as the edit distance between the
+        two provided sentences after tokenizing to words.
+        Arguments:
+            s1 (string): space-separated sentence
+            s2 (string): space-separated sentence
+        """
+
+        # build mapping of words to integers
+        b = set(s1.split() + s2.split())
+        word2char = dict(zip(b, range(len(b))))
+
+        # map the words to a char array (the Levenshtein package only accepts
+        # strings)
+        w1 = [chr(word2char[w]) for w in s1.split()]
+        w2 = [chr(word2char[w]) for w in s2.split()]
+
+        return Lev.distance(''.join(w1), ''.join(w2))
+
+
+if __name__ == '__main__':
+    inference_tsv = '\\\\192.168.5.3\\rtx3080_drv\\dlas\\codes\\eval_libritts_for_gpt_asr_results_WER=2.6615.tsv'
+    libri_base = 'Z:\\libritts\\test-clean'
+
+    wer = WordErrorRate()
+    wer_scores = []
+    with open(inference_tsv, 'r') as tsv_file:
+        tsv = tsv_file.read().splitlines()
+        # Each TSV line is "<predicted transcript>\t<wav filename>".
+        for line in tqdm(tsv):
+            sentence_pred, wav = line.split('\t')
+            sentence_pred = normalize_text(sentence_pred)
+
+            wav_comp = wav.split('_')
+            reader = wav_comp[0]
+            book = wav_comp[1]
+            # LibriTTS keeps the reference transcript next to the wav as *.normalized.txt.
+            txt_file = os.path.join(libri_base, reader, book, wav.replace('.wav', '.normalized.txt'))
+            with open(txt_file, 'r') as txt_file_hndl:
+                txt_uncleaned = txt_file_hndl.read()
+            sentence_real = normalize_text(clean_text(txt_uncleaned))
+            wer_scores.append(wer.wer_calc(sentence_real, sentence_pred))
+    print(f"WER: {torch.tensor(wer_scores, dtype=torch.float).mean()}")
\ No newline at end of file
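
For readers skimming the patch, here is a minimal standalone sketch of the word-to-character mapping trick that wer_calc() relies on: every distinct word is assigned a unique code point, so the character-level Levenshtein.distance over the encoded strings equals the word-level edit distance. The word_edit_distance() and wer() helper names and the example sentences are illustrative only and not part of the patch; the final division by the reference word count mirrors what WordErrorRate.compute() does with n_tokens.

# Illustrative sketch only -- not part of the patch above.
import Levenshtein as Lev


def word_edit_distance(hyp, ref):
    # Give every distinct word its own code point; the character-level distance
    # over the encoded strings is then exactly the word-level edit distance.
    vocab = {word: i for i, word in enumerate(set(hyp.split() + ref.split()))}

    def encode(sentence):
        return ''.join(chr(vocab[word]) for word in sentence.split())

    return Lev.distance(encode(hyp), encode(ref))


def wer(hyp, ref):
    # Conventional WER: word edit distance divided by the reference word count.
    return word_edit_distance(hyp, ref) / max(1, len(ref.split()))


if __name__ == '__main__':
    # One word missing from a six-word reference -> WER of 1/6.
    print(wer('THE CAT SAT ON MAT', 'THE CAT SAT ON THE MAT'))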