update wer script

James Betker 2022-01-13 17:08:49 -07:00
parent 009a1e8404
commit 87c83e4957


@@ -1,91 +1,43 @@
-# Original source: https://github.com/SeanNaren/deepspeech.pytorch/blob/master/deepspeech_pytorch/validation.py
-import os
-import Levenshtein as Lev
+import Levenshtein
+from jiwer import wer, compute_measures
 import torch
 from tqdm import tqdm
 from data.audio.voice_tokenizer import VoiceBpeTokenizer
-from models.tacotron2.text import cleaners
-
-
-def clean_text(text):
-    for name in ['english_cleaners']:
-        cleaner = getattr(cleaners, name)
-        if not cleaner:
-            raise Exception('Unknown cleaner: %s' % name)
-        text = cleaner(text)
-    return text
-
-
-# Converts text to all-uppercase and separates punctuation from words.
-def normalize_text(text):
-    text = text.upper()
-    for punc in ['.', ',', ':', ';']:
-        text = text.replace(punc, f' {punc}')
-    return text.strip()
-
-
-class WordErrorRate:
-    def calculate_metric(self, transcript, reference):
-        wer_inst = self.wer_calc(transcript, reference)
-        self.wer += wer_inst
-        self.n_tokens += len(reference.split())
-
-    def compute(self):
-        wer = float(self.wer) / self.n_tokens
-        return wer.item() * 100
-
-    def wer_calc(self, s1, s2):
-        """
-        Computes the Word Error Rate, defined as the edit distance between the
-        two provided sentences after tokenizing to words.
-        Arguments:
-            s1 (string): space-separated sentence
-            s2 (string): space-separated sentence
-        """
-        # build mapping of words to integers
-        b = set(s1.split() + s2.split())
-        word2char = dict(zip(b, range(len(b))))
-        # map the words to a char array (Levenshtein packages only accepts
-        # strings)
-        w1 = [chr(word2char[w]) for w in s1.split()]
-        w2 = [chr(word2char[w]) for w in s2.split()]
-        return Lev.distance(''.join(w1), ''.join(w2))
 
 
 def load_truths(file):
     niltok = VoiceBpeTokenizer(None)
     out = {}
     with open(file, 'r', encoding='utf-8') as f:
-        for line in f.readline():
+        for line in f.readlines():
             spl = line.split('|')
             if len(spl) != 2:
+                print(spl)
                 continue
             path, truth = spl
-            path = path.replace('wav/', '')
-            truth = niltok.preprocess_text(truth)  # This may or may not be considered a "cheat", but the model is only trained on preprocessed text.
+            #path = path.replace('wav/', '')
+            # This preprocesses the truth data in the same way that training data is processed: removing punctuation, all lowercase, removing unnecessary
+            # whitespace, and applying "english cleaners", which convert words like "mrs" to "missus" and such.
+            truth = niltok.preprocess_text(truth)
             out[path] = truth
     return out
 
 
 if __name__ == '__main__':
     inference_tsv = 'results.tsv'
-    libri_base = '/h/bigasr_dataset/librispeech/test_clean/test_clean.txt'
+    libri_base = 'y:\\bigasr_dataset/librispeech/test_clean/test_clean.txt'
 
     # Pre-process truth values
     truths = load_truths(libri_base)
 
-    wer = WordErrorRate()
-    wer_scores = []
+    ground_truths = []
+    hypotheses = []
     with open(inference_tsv, 'r') as tsv_file:
         tsv = tsv_file.read().splitlines()
         for line in tqdm(tsv):
             sentence_pred, wav = line.split('\t')
-            sentence_pred = normalize_text(sentence_pred)
-            sentence_real = normalize_text(truths[wav])
-            wer_scores.append(wer.wer_calc(sentence_real, sentence_pred))
-    print(f"WER: {torch.tensor(wer_scores, dtype=torch.float).mean()}")
+            hypotheses.append(sentence_pred)
+            ground_truths.append(truths[wav])
+    wer = wer(ground_truths, hypotheses)*100
+    print(f"WER: {wer}")