WER script
This commit is contained in:
parent
3a9d1c53ea
commit
15437b2fc3
79
codes/scripts/audio/word_error_rate.py
Normal file
79
codes/scripts/audio/word_error_rate.py
Normal file
|
@ -0,0 +1,79 @@
|
||||||
|
# Original source: https://github.com/SeanNaren/deepspeech.pytorch/blob/master/deepspeech_pytorch/validation.py
|
||||||
|
import os
|
||||||
|
|
||||||
|
import Levenshtein as Lev
|
||||||
|
import torch
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from models.tacotron2.text import cleaners
|
||||||
|
|
||||||
|
|
||||||
|
def clean_text(text):
|
||||||
|
for name in ['english_cleaners']:
|
||||||
|
cleaner = getattr(cleaners, name)
|
||||||
|
if not cleaner:
|
||||||
|
raise Exception('Unknown cleaner: %s' % name)
|
||||||
|
text = cleaner(text)
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
# Converts text to all-uppercase and separates punctuation from words.
|
||||||
|
def normalize_text(text):
|
||||||
|
text = text.upper()
|
||||||
|
for punc in ['.', ',', ':', ';']:
|
||||||
|
text = text.replace(punc, f' {punc}')
|
||||||
|
return text.strip()
|
||||||
|
|
||||||
|
|
||||||
|
class WordErrorRate:
|
||||||
|
def calculate_metric(self, transcript, reference):
|
||||||
|
wer_inst = self.wer_calc(transcript, reference)
|
||||||
|
self.wer += wer_inst
|
||||||
|
self.n_tokens += len(reference.split())
|
||||||
|
|
||||||
|
def compute(self):
|
||||||
|
wer = float(self.wer) / self.n_tokens
|
||||||
|
return wer.item() * 100
|
||||||
|
|
||||||
|
def wer_calc(self, s1, s2):
|
||||||
|
"""
|
||||||
|
Computes the Word Error Rate, defined as the edit distance between the
|
||||||
|
two provided sentences after tokenizing to words.
|
||||||
|
Arguments:
|
||||||
|
s1 (string): space-separated sentence
|
||||||
|
s2 (string): space-separated sentence
|
||||||
|
"""
|
||||||
|
|
||||||
|
# build mapping of words to integers
|
||||||
|
b = set(s1.split() + s2.split())
|
||||||
|
word2char = dict(zip(b, range(len(b))))
|
||||||
|
|
||||||
|
# map the words to a char array (Levenshtein packages only accepts
|
||||||
|
# strings)
|
||||||
|
w1 = [chr(word2char[w]) for w in s1.split()]
|
||||||
|
w2 = [chr(word2char[w]) for w in s2.split()]
|
||||||
|
|
||||||
|
return Lev.distance(''.join(w1), ''.join(w2))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
inference_tsv = '\\\\192.168.5.3\\rtx3080_drv\\dlas\\codes\\eval_libritts_for_gpt_asr_results_WER=2.6615.tsv'
|
||||||
|
libri_base = 'Z:\\libritts\\test-clean'
|
||||||
|
|
||||||
|
wer = WordErrorRate()
|
||||||
|
wer_scores = []
|
||||||
|
with open(inference_tsv, 'r') as tsv_file:
|
||||||
|
tsv = tsv_file.read().splitlines()
|
||||||
|
for line in tqdm(tsv):
|
||||||
|
sentence_pred, wav = line.split('\t')
|
||||||
|
sentence_pred = normalize_text(sentence_pred)
|
||||||
|
|
||||||
|
wav_comp = wav.split('_')
|
||||||
|
reader = wav_comp[0]
|
||||||
|
book = wav_comp[1]
|
||||||
|
txt_file = os.path.join(libri_base, reader, book, wav.replace('.wav', '.normalized.txt'))
|
||||||
|
with open(txt_file, 'r') as txt_file_hndl:
|
||||||
|
txt_uncleaned = txt_file_hndl.read()
|
||||||
|
sentence_real = normalize_text(clean_text(txt_uncleaned))
|
||||||
|
wer_scores.append(wer.wer_calc(sentence_real, sentence_pred))
|
||||||
|
print(f"WER: {torch.tensor(wer_scores, dtype=torch.float).mean()}")
|
Loading…
Reference in New Issue
Block a user