From 4775edaa41a509e5b38f0da11b201ad6fd77b573 Mon Sep 17 00:00:00 2001 From: mrq Date: Wed, 18 Dec 2024 19:58:53 -0600 Subject: [PATCH] added text cleaning/normalization for wer purposes but it amounts to nothing desu --- vall_e/data.py | 133 +++++++++++++++++++++++++++++++++++++++++++--- vall_e/demo.py | 5 +- vall_e/metrics.py | 15 ++++-- 3 files changed, 140 insertions(+), 13 deletions(-) diff --git a/vall_e/data.py b/vall_e/data.py index e2243fb..ea066d8 100755 --- a/vall_e/data.py +++ b/vall_e/data.py @@ -63,12 +63,133 @@ def sentence_split( s, split_by="sentences", quote_placeholder="" ): sentences = nltk.sent_tokenize(s) return [ sentence.replace(quote_placeholder, '"') for sentence in sentences if sentence ] -# to-do: improve upon this since it's kind of ass -# this might be better to live in emb.g2p -def normalize_text( s ): - s = s.lower() - s = re.sub(r'[^\w\s]', '', s) - return s +# normalization code borrowed from TorToiSe TTS +# (it's not perfect but it works) + +try: + from tokenizers.normalizers import Lowercase, NFD, StripAccents + + normalizer = tokenizers.normalizers.Sequence([Lowercase(), NFD(), StripAccents()]) +except Exception as e: + normalizer = None + +# List of (regular expression, replacement) pairs for abbreviations: +_abbreviations = [(re.compile('\\b%s\\.' 
% x[0], re.IGNORECASE), x[1]) for x in [ + ('mrs', 'misess'), + ('mr', 'mister'), + ('dr', 'doctor'), + ('st', 'saint'), + ('co', 'company'), + ('jr', 'junior'), + ('maj', 'major'), + ('gen', 'general'), + ('drs', 'doctors'), + ('rev', 'reverend'), + ('lt', 'lieutenant'), + ('hon', 'honorable'), + ('sgt', 'sergeant'), + ('capt', 'captain'), + ('esq', 'esquire'), + ('ltd', 'limited'), + ('col', 'colonel'), + ('ft', 'fort'), +]] +def normalize_abbreviations(text): + for regex, replacement in _abbreviations: + text = re.sub(regex, replacement, text) + return text + +def _remove_commas(m): + return m.group(1).replace(',', '') + +def _expand_decimal_point(m): + return m.group(1).replace('.', ' point ') + +def _expand_dollars(m): + match = m.group(1) + parts = match.split('.') + if len(parts) > 2: + return match + ' dollars' # Unexpected format + dollars = int(parts[0]) if parts[0] else 0 + cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0 + if dollars and cents: + dollar_unit = 'dollar' if dollars == 1 else 'dollars' + cent_unit = 'cent' if cents == 1 else 'cents' + return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit) + elif dollars: + dollar_unit = 'dollar' if dollars == 1 else 'dollars' + return '%s %s' % (dollars, dollar_unit) + elif cents: + cent_unit = 'cent' if cents == 1 else 'cents' + return '%s %s' % (cents, cent_unit) + else: + return 'zero dollars' + +# in case the current env does not have it installed, so I don't need it as a hard dependency +try: + import inflect + + _inflect = inflect.engine() + + def _expand_ordinal(m): + return _inflect.number_to_words(m.group(0)) + + def _expand_number(m): + num = int(m.group(0)) + if num > 1000 and num < 3000: + if num == 2000: + return 'two thousand' + elif num > 2000 and num < 2010: + return 'two thousand ' + _inflect.number_to_words(num % 100) + elif num % 100 == 0: + return _inflect.number_to_words(num // 100) + ' hundred' + else: + return _inflect.number_to_words(num, andword='', zero='oh', 
group=2).replace(', ', ' ') + else: + return _inflect.number_to_words(num, andword='') +except Exception as e: + _inflect = None + +_comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])') +_decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)') +_pounds_re = re.compile(r'£([0-9\,]*[0-9]+)') +_dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)') +_ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)') +_number_re = re.compile(r'[0-9]+') +_whitespace_re = re.compile(r'\s+') +_end_punct_re = re.compile(r'[\.\?\!]$') +_aux_punct_re = re.compile(r'[,;:\?\.\!-]') + +def normalize_numbers(text): + text = re.sub(_comma_number_re, _remove_commas, text) + text = re.sub(_pounds_re, r'\1 pounds', text) + text = re.sub(_dollars_re, _expand_dollars, text) + text = re.sub(_decimal_number_re, _expand_decimal_point, text) + if _inflect is not None: + text = re.sub(_ordinal_re, _expand_ordinal, text) + text = re.sub(_number_re, _expand_number, text) + return text + +# full will do aggressive normalization, perfect for WER/CER +# not full will do basic cleaning +def normalize_text(text, language="auto", full=True): + if full: + if normalizer is not None: + text = normalizer.normalize_str( text ) + else: + text = text.lower() + text = normalize_numbers(text) # expand numbers + text = normalize_abbreviations(text) # expand abbreviations + #text = re.sub(_end_punct_re, '', text) # strip trailing punctuation + text = re.sub(_aux_punct_re, '', text) # strip punctuation + text = text.replace('"', '') # remove quotation marks + else: + text = normalize_numbers(text) # expand numbers + text = normalize_abbreviations(text) # expand abbreviations + text = re.sub(_whitespace_re, ' ', text) # collapse whitespace + + # to-do: other languages + return text @cache def get_random_prompts( validation=False, min_length=0, tokenized=False ): diff --git a/vall_e/demo.py b/vall_e/demo.py index 62fc923..092eb0d 100644 --- a/vall_e/demo.py +++ b/vall_e/demo.py @@ -135,7 +135,7 @@ def main(): parser.add_argument("--lora", 
action="store_true") parser.add_argument("--comparison", type=str, default=None) - parser.add_argument("--transcription-model", type=str, default="openai/whisper-base") + parser.add_argument("--transcription-model", type=str, default="openai/whisper-large-v3") parser.add_argument("--speaker-similarity-model", type=str, default="microsoft/wavlm-large") args = parser.parse_args() @@ -426,7 +426,8 @@ def main(): calculate = not metrics_path.exists() or (metrics_path.stat().st_mtime < out_path.stat().st_mtime) if calculate: - wer_score, cer_score = wer( out_path, text, language=language, device=tts.device, dtype=tts.dtype, model_name=args.transcription_model ) + wer_score, cer_score = wer( out_path, text, language=language, device=tts.device, dtype=tts.dtype, model_name=args.transcription_model, phonemize=True ) + #wer_score, cer_score = wer( out_path, reference_path, language=language, device=tts.device, dtype=tts.dtype, model_name=args.transcription_model, phonemize=False ) sim_o_score = sim_o( out_path, prompt_path, device=tts.device, dtype=tts.dtype, model_name=args.speaker_similarity_model ) metrics = {"wer": wer_score, "cer": cer_score, "sim-o": sim_o_score} diff --git a/vall_e/metrics.py b/vall_e/metrics.py index 8e91e0f..3fef2aa 100644 --- a/vall_e/metrics.py +++ b/vall_e/metrics.py @@ -12,13 +12,19 @@ from pathlib import Path from torcheval.metrics.functional import word_error_rate from torchmetrics.functional.text import char_error_rate -def wer( audio, reference, language="auto", normalize=True, phonemize=True, **transcription_kwargs ): +import warnings +warnings.simplefilter(action='ignore', category=FutureWarning) +warnings.simplefilter(action='ignore', category=UserWarning) + +def wer( audio, reference, language="auto", phonemize=True, **transcription_kwargs ): if language == "auto": language = detect_language( reference ) transcription = transcribe( audio, language=language, align=False, **transcription_kwargs ) + if language == "auto": language = 
transcription["language"] + transcription = transcription["text"] # reference audio needs transcribing too @@ -29,13 +35,12 @@ def wer( audio, reference, language="auto", normalize=True, phonemize=True, **tr transcription = coerce_to_hiragana( transcription ) reference = coerce_to_hiragana( reference ) - if normalize: - transcription = normalize_text( transcription ) - reference = normalize_text( reference ) - if phonemize: transcription = encode( transcription, language=language ) reference = encode( reference, language=language ) + else: + transcription = normalize_text( transcription, language=language ) + reference = normalize_text( reference, language=language ) wer_score = word_error_rate([transcription], [reference]).item() # un-normalize