wer update

commit 17fb934575
parent f0c4cd6317
data/audio/voice_tokenizer.py:

@@ -13,13 +13,35 @@ from models.tacotron2.taco_utils import load_filepaths_and_text
 from models.tacotron2.text.cleaners import english_cleaners
 
 
+def remove_extraneous_punctuation(word):
+    replacement_punctuation = {
+        '{': '(', '}': ')',
+        '[': '(', ']': ')',
+        '`': '\'', '—': '-',
+        '—': '-', '`': '\'',
+        'ʼ': '\''
+    }
+    replace = re.compile("|".join([re.escape(k) for k in sorted(replacement_punctuation, key=len, reverse=True)]), flags=re.DOTALL)
+    word = replace.sub(lambda x: replacement_punctuation[x.group(0)], word)
+
+    # TODO: some of these are spoken ('@', '%', '+', etc). Integrate them into the cleaners.
+    extraneous = re.compile(r'^[@#%_=\$\^&\*\+\\]$')
+    word = extraneous.sub('', word)
+    return word
+
+
 class VoiceBpeTokenizer:
     def __init__(self, vocab_file):
-        self.tokenizer = Tokenizer.from_file(vocab_file)
+        if vocab_file is not None:
+            self.tokenizer = Tokenizer.from_file(vocab_file)
 
-    def encode(self, txt):
+    def preprocess_text(self, txt):
         txt = english_cleaners(txt)
         txt = remove_extraneous_punctuation(txt)
+        return txt
+
+    def encode(self, txt):
+        txt = self.preprocess_text(txt)
         txt = txt.replace(' ', '[SPACE]')
         return self.tokenizer.encode(txt).ids
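The None guard in __init__ is what lets the WER script further down build a preprocessing-only instance via VoiceBpeTokenizer(None). A minimal sketch of that usage (illustrative input string, no vocab file involved):

    niltok = VoiceBpeTokenizer(None)
    txt = niltok.preprocess_text('{Hello} there!')  # english_cleaners + punctuation normalization only
    # niltok.encode(txt) would raise AttributeError here: self.tokenizer was never assigned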
@@ -28,6 +50,9 @@ class VoiceBpeTokenizer:
         seq = seq.cpu().numpy()
         txt = self.tokenizer.decode(seq, skip_special_tokens=False).replace(' ', '')
         txt = txt.replace('[SPACE]', ' ')
+        txt = txt.replace('[STOP]', '')
+        txt = txt.replace('[UNK]', '')
         return txt
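Taken together with encode above, the intended round trip looks like the sketch below. Assumptions not shown in the diff: the enclosing method is named decode and takes a torch tensor (the seq.cpu().numpy() context line suggests as much), and 'bpe.json' stands in for a real vocab file loadable by the huggingface tokenizers Tokenizer.from_file:

    import torch

    tok = VoiceBpeTokenizer('bpe.json')   # hypothetical vocab file
    ids = tok.encode('hello there')       # cleaned text, ' ' encoded as [SPACE]
    txt = tok.decode(torch.tensor(ids))   # '[SPACE]' -> ' '; '[STOP]'/'[UNK]' now stripped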
@@ -50,23 +75,6 @@ def build_text_file_from_priors(priors, output):
                 out.flush()
 
 
-def remove_extraneous_punctuation(word):
-    replacement_punctuation = {
-        '{': '(', '}': ')',
-        '[': '(', ']': ')',
-        '`': '\'', '—': '-',
-        '—': '-', '`': '\'',
-        'ʼ': '\''
-    }
-    replace = re.compile("|".join([re.escape(k) for k in sorted(replacement_punctuation, key=len, reverse=True)]), flags=re.DOTALL)
-    word = replace.sub(lambda x: replacement_punctuation[x.group(0)], word)
-
-    # TODO: some of these are spoken ('@', '%', '+', etc). Integrate them into the cleaners.
-    extraneous = re.compile(r'^[@#%_=\$\^&\*\+\\]$')
-    word = extraneous.sub('', word)
-    return word
-
-
 def train():
     with open('all_texts.txt', 'r', encoding='utf-8') as at:
         ttsd = at.readlines()
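The helper's body is unchanged by the move to module level; quick illustrative checks of its behavior (not part of the commit):

    assert remove_extraneous_punctuation('{hello}') == '(hello)'  # bracket variants become parens
    assert remove_extraneous_punctuation('@') == ''               # a lone unspoken symbol is dropped
    assert remove_extraneous_punctuation('a@b') == 'a@b'          # the ^...$ pattern only matches whole-token symbols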
In the WER evaluation script:

@@ -5,6 +5,7 @@ import Levenshtein as Lev
 import torch
 from tqdm import tqdm
 
+from data.audio.voice_tokenizer import VoiceBpeTokenizer
 from models.tacotron2.text import cleaners
 
@@ -56,9 +57,27 @@ class WordErrorRate:
         return Lev.distance(''.join(w1), ''.join(w2))
 
 
+def load_truths(file):
+    niltok = VoiceBpeTokenizer(None)
+    out = {}
+    with open(file, 'r', encoding='utf-8') as f:
+        for line in f:  # one 'path|transcript' entry per line
+            spl = line.split('|')
+            if len(spl) != 2:
+                continue
+            path, truth = spl
+            path = path.replace('wav/', '')
+            truth = niltok.preprocess_text(truth)  # This may or may not be considered a "cheat", but the model is only trained on preprocessed text.
+            out[path] = truth
+    return out
+
+
 if __name__ == '__main__':
-    inference_tsv = 'D:\\dlas\\codes\\results.tsv'
-    libri_base = 'Z:\\libritts\\test-clean'
+    inference_tsv = 'results.tsv'
+    libri_base = '/h/bigasr_dataset/librispeech/test_clean/test_clean.txt'
+
+    # Pre-process truth values
+    truths = load_truths(libri_base)
 
     wer = WordErrorRate()
     wer_scores = []
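load_truths expects a pipe-delimited transcript index, one utterance per line; an illustrative entry (hypothetical data, matching the split('|') and the 'wav/' strip):

    wav/1089-134686-0000.wav|HE HOPED THERE WOULD BE STEW FOR DINNER

Keys of the returned dict therefore carry no 'wav/' prefix, so they line up with the wav column of results.tsv used in the loop below.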
@@ -67,13 +86,6 @@ if __name__ == '__main__':
         for line in tqdm(tsv):
             sentence_pred, wav = line.split('\t')
             sentence_pred = normalize_text(sentence_pred)
-            wav_comp = wav.split('_')
-            reader = wav_comp[0]
-            book = wav_comp[1]
-            txt_file = os.path.join(libri_base, reader, book, wav.replace('.wav', '.normalized.txt'))
-            with open(txt_file, 'r') as txt_file_hndl:
-                txt_uncleaned = txt_file_hndl.read()
-            sentence_real = normalize_text(clean_text(txt_uncleaned))
+            sentence_real = normalize_text(truths[wav])
             wer_scores.append(wer.wer_calc(sentence_real, sentence_pred))
     print(f"WER: {torch.tensor(wer_scores, dtype=torch.float).mean()}")
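The wer_calc line shown above appears to return a raw word-level edit distance: words are mapped to stand-in characters and joined so Lev.distance counts whole-word edits. A standalone sketch of the same idea, normalized by reference length as conventional WER is (illustrative helper, not from this commit):

    import Levenshtein as Lev

    def word_error_rate(truth: str, pred: str) -> float:
        t, p = truth.split(), pred.split()
        # map each distinct word to a unique stand-in character so the
        # character-level edit distance counts whole-word edits
        stand_in = {w: chr(0x4E00 + i) for i, w in enumerate(set(t) | set(p))}
        edits = Lev.distance(''.join(stand_in[w] for w in t), ''.join(stand_in[w] for w in p))
        return edits / max(len(t), 1)

    word_error_rate('the cat sat', 'the cat sit')  # 1 substitution / 3 words ~= 0.33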