wer update

parent f0c4cd6317
commit 17fb934575
data/audio/voice_tokenizer.py:

@@ -13,13 +13,35 @@ from models.tacotron2.taco_utils import load_filepaths_and_text
 from models.tacotron2.text.cleaners import english_cleaners
 
 
+def remove_extraneous_punctuation(word):
+    replacement_punctuation = {
+        '{': '(', '}': ')',
+        '[': '(', ']': ')',
+        '`': '\'', '—': '-',
+        '—': '-', '`': '\'',
+        'ʼ': '\''
+    }
+    replace = re.compile("|".join([re.escape(k) for k in sorted(replacement_punctuation, key=len, reverse=True)]), flags=re.DOTALL)
+    word = replace.sub(lambda x: replacement_punctuation[x.group(0)], word)
+
+    # TODO: some of these are spoken ('@', '%', '+', etc). Integrate them into the cleaners.
+    extraneous = re.compile(r'^[@#%_=\$\^&\*\+\\]$')
+    word = extraneous.sub('', word)
+    return word
+
+
 class VoiceBpeTokenizer:
     def __init__(self, vocab_file):
-        self.tokenizer = Tokenizer.from_file(vocab_file)
+        if vocab_file is not None:
+            self.tokenizer = Tokenizer.from_file(vocab_file)
 
-    def encode(self, txt):
+    def preprocess_text(self, txt):
         txt = english_cleaners(txt)
+        txt = remove_extraneous_punctuation(txt)
+        return txt
+
+    def encode(self, txt):
+        txt = self.preprocess_text(txt)
         txt = txt.replace(' ', '[SPACE]')
         return self.tokenizer.encode(txt).ids
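The moved-up remove_extraneous_punctuation can be sanity-checked on its own. Below is a minimal standalone sketch of the same technique (an alternation regex built from the escaped dict keys, then a second regex dropping tokens that are a single unspoken symbol); the name normalize_punctuation and the test strings are illustrative, not part of the commit:

import re

replacement_punctuation = {
    '{': '(', '}': ')',
    '[': '(', ']': ')',
    '`': '\'', '—': '-',
    'ʼ': '\''
}
# Longest keys first, so multi-character entries would win over their prefixes.
pattern = re.compile("|".join(re.escape(k) for k in sorted(replacement_punctuation, key=len, reverse=True)))

def normalize_punctuation(word):
    word = pattern.sub(lambda m: replacement_punctuation[m.group(0)], word)
    # A token consisting solely of one unspoken symbol is dropped entirely.
    return re.sub(r'^[@#%_=\$\^&\*\+\\]$', '', word)

assert normalize_punctuation('{quoted}') == '(quoted)'
assert normalize_punctuation('@') == ''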
@@ -28,6 +50,9 @@ class VoiceBpeTokenizer:
             seq = seq.cpu().numpy()
         txt = self.tokenizer.decode(seq, skip_special_tokens=False).replace(' ', '')
         txt = txt.replace('[SPACE]', ' ')
+        txt = txt.replace('[STOP]', '')
+        txt = txt.replace('[UNK]', '')
+
         return txt
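Together the two hunks above give encode/decode a symmetric contract: encode cleans the text and swaps spaces for [SPACE]; decode strips the special tokens back out. A hypothetical round trip, where the vocab path is a placeholder and not a file shipped in this commit:

import torch
from data.audio.voice_tokenizer import VoiceBpeTokenizer

tok = VoiceBpeTokenizer('path/to/bpe_vocab.json')  # placeholder: any trained tokenizers vocab file
ids = tok.encode('Hello there!')      # english_cleaners -> punctuation fixes -> [SPACE] -> BPE ids
text = tok.decode(torch.tensor(ids))  # drops [STOP]/[UNK], turns [SPACE] back into ' '
print(ids, text)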
@@ -50,23 +75,6 @@ def build_text_file_from_priors(priors, output):
                 out.flush()
 
 
-def remove_extraneous_punctuation(word):
-    replacement_punctuation = {
-        '{': '(', '}': ')',
-        '[': '(', ']': ')',
-        '`': '\'', '—': '-',
-        '—': '-', '`': '\'',
-        'ʼ': '\''
-    }
-    replace = re.compile("|".join([re.escape(k) for k in sorted(replacement_punctuation, key=len, reverse=True)]), flags=re.DOTALL)
-    word = replace.sub(lambda x: replacement_punctuation[x.group(0)], word)
-
-    # TODO: some of these are spoken ('@', '%', '+', etc). Integrate them into the cleaners.
-    extraneous = re.compile(r'^[@#%_=\$\^&\*\+\\]$')
-    word = extraneous.sub('', word)
-    return word
-
-
 def train():
     with open('all_texts.txt', 'r', encoding='utf-8') as at:
         ttsd = at.readlines()
A second file, the WER evaluation script, is also touched:

@@ -5,6 +5,7 @@ import Levenshtein as Lev
 import torch
 from tqdm import tqdm
 
+from data.audio.voice_tokenizer import VoiceBpeTokenizer
 from models.tacotron2.text import cleaners
@@ -56,9 +57,27 @@ class WordErrorRate:
         return Lev.distance(''.join(w1), ''.join(w2))
 
 
+def load_truths(file):
+    niltok = VoiceBpeTokenizer(None)
+    out = {}
+    with open(file, 'r', encoding='utf-8') as f:
+        for line in f.readlines():
+            spl = line.split('|')
+            if len(spl) != 2:
+                continue
+            path, truth = spl
+            path = path.replace('wav/', '')
+            truth = niltok.preprocess_text(truth)  # This may or may not be considered a "cheat", but the model is only trained on preprocessed text.
+            out[path] = truth
+    return out
+
+
 if __name__ == '__main__':
-    inference_tsv = 'D:\\dlas\\codes\\results.tsv'
-    libri_base = 'Z:\\libritts\\test-clean'
+    inference_tsv = 'results.tsv'
+    libri_base = '/h/bigasr_dataset/librispeech/test_clean/test_clean.txt'
+
+    # Pre-process truth values
+    truths = load_truths(libri_base)
 
     wer = WordErrorRate()
     wer_scores = []
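load_truths expects a pipe-delimited transcript file: any line without exactly one '|' is skipped, a 'wav/' prefix is trimmed from the path, and the truth text goes through the same preprocessing the model was trained on. A hypothetical input line and the entry it would produce (the sample utterance is illustrative, not from the dataset):

# One line of test_clean.txt:
#   wav/1089_134686_000001_000001.wav|CONCORD RETURNED TO ITS PLACE
# yields (english_cleaners lowercases):
#   truths['1089_134686_000001_000001.wav'] == 'concord returned to its place'
truths = load_truths('/h/bigasr_dataset/librispeech/test_clean/test_clean.txt')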
@@ -67,13 +86,6 @@ if __name__ == '__main__':
     for line in tqdm(tsv):
         sentence_pred, wav = line.split('\t')
         sentence_pred = normalize_text(sentence_pred)
 
-        wav_comp = wav.split('_')
-        reader = wav_comp[0]
-        book = wav_comp[1]
-        txt_file = os.path.join(libri_base, reader, book, wav.replace('.wav', '.normalized.txt'))
-        with open(txt_file, 'r') as txt_file_hndl:
-            txt_uncleaned = txt_file_hndl.read()
-        sentence_real = normalize_text(clean_text(txt_uncleaned))
-        wer_scores.append(wer.wer_calc(sentence_real, sentence_pred))
-    print(f"WER: {torch.tensor(wer_scores, dtype=torch.float).mean()}")
+        sentence_real = normalize_text(truths[wav])
+        wer_scores.append(wer.wer_calc(sentence_real, sentence_pred))
+    print(f"WER: {torch.tensor(wer_scores, dtype=torch.float).mean()}")
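WordErrorRate.wer_calc itself is not shown in this diff, but the visible return Lev.distance(''.join(w1), ''.join(w2)) line suggests the usual trick: map each distinct word to a unique character so character-level Levenshtein distance counts word-level edits. A self-contained sketch under that assumption (function name and test strings are illustrative):

import Levenshtein as Lev

def word_edit_distance(ref, hyp):
    # Assign every distinct word a unique one-character stand-in, then let
    # Lev.distance count insertions/deletions/substitutions of whole words.
    vocab = {w: chr(i) for i, w in enumerate(set(ref.split() + hyp.split()))}
    r = ''.join(vocab[w] for w in ref.split())
    h = ''.join(vocab[w] for w in hyp.split())
    return Lev.distance(r, h)

print(word_edit_distance('the cat sat', 'the cat sat down'))  # one inserted word -> 1

Note that the script averages these raw edit counts across sentences; a conventional WER would divide each count by the reference length in words before averaging.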