diff --git a/data/tongue_twisters.txt b/data/tongue_twisters.txt new file mode 100644 index 0000000..b30fe24 --- /dev/null +++ b/data/tongue_twisters.txt @@ -0,0 +1,23 @@ +Six sick hicks nick six slick bricks with picks and sticks. +Fresh French fried fly fritters. +Rory the warrior and Roger the worrier were reared wrongly in a rural brewery. +Which wrist watches are Swiss wrist watches? +Fred fed Ted bread and Ted fed Fred bread. +The 33 thieves thought that they thrilled the throne throughout Thursday. +You know New York, you need New York, you know you need unique New York. +Lesser leather never weathered wetter weather better. +The sixth sick sheikh’s sixth sheep’s sick. +A skunk sat on a stump and thunk the stump stunk, but the stump thunk the skunk stunk. +Thirty-three thirsty, thundering thoroughbreds thumped Mr. Thurber on Thursday. +Wayne went to Wales to watch walruses. +Seventy-seven benevolent elephants. +Send toast to ten tense stout saints’ ten tall tents. +I slit the sheet, the sheet I slit, and on the slitted sheet I sit. +Give papa a cup of proper coffee in a copper coffee cup. +She sells seashells by the seashore. +Peter Piper picked a peck of pickled peppers. How many pickled peppers did Peter Piper pick? +Pad kid poured curd pulled cod. +Fuzzy Wuzzy was a bear. Fuzzy Wuzzy had no hair. Fuzzy Wuzzy wasn’t very fuzzy, was he? +Supercalifragilisticexpialidocious. +How much wood would a woodchuck chuck if a woodchuck could chuck wood? He would chuck, he would, as much as he could, and chuck as much wood as a woodchuck would if a woodchuck could chuck wood. +Buffalo buffalo Buffalo buffalo buffalo buffalo Buffalo buffalo. 
\ No newline at end of file diff --git a/vall_e/data.py b/vall_e/data.py index ea066d8..9fcfe98 100755 --- a/vall_e/data.py +++ b/vall_e/data.py @@ -192,7 +192,7 @@ def normalize_text(text, language="auto", full=True): return text @cache -def get_random_prompts( validation=False, min_length=0, tokenized=False ): +def get_random_prompts( validation=False, min_length=0, tokenized=False, source_path=Path("./data/harvard_sentences.txt") ): duration_range = [ 5.5, 12.0 ] # to-do: pull from cfg.dataset.duration_range sentences = [ "The birch canoe slid on the smooth planks.", @@ -228,9 +228,8 @@ def get_random_prompts( validation=False, min_length=0, tokenized=False ): "Perfect. Please move quickly to the chamber lock, as the effect of prolonged exposure to the button are not part of this test.", ] - harvard_sentences_path = Path("./data/harvard_sentences.txt") - if harvard_sentences_path.exists(): - sentences = open( harvard_sentences_path, "r", encoding="utf-8" ).read().split("\n") + if source_path.exists(): + sentences = open( source_path, "r", encoding="utf-8" ).read().split("\n") # Pull from validation dataset if existing + requested if validation and cfg.dataset.validation: diff --git a/vall_e/demo.py b/vall_e/demo.py index 092eb0d..16ab412 100644 --- a/vall_e/demo.py +++ b/vall_e/demo.py @@ -426,15 +426,20 @@ def main(): calculate = not metrics_path.exists() or (metrics_path.stat().st_mtime < out_path.stat().st_mtime) if calculate: - wer_score, cer_score = wer( out_path, text, language=language, device=tts.device, dtype=tts.dtype, model_name=args.transcription_model, phonemize=True ) - #wer_score, cer_score = wer( out_path, reference_path, language=language, device=tts.device, dtype=tts.dtype, model_name=args.transcription_model, phonemize=False ) + # computes based on word transcriptions outright + wer_score, cer_score = wer( out_path, text, language=language, device=tts.device, dtype=tts.dtype, model_name=args.transcription_model, phonemize=False ) + # compute 
on words as well, but does not normalize + wer_un_score, cer_un_score = wer( out_path, text, language=language, device=tts.device, dtype=tts.dtype, model_name=args.transcription_model, phonemize=False, normalize=False ) + # computes on phonemes instead + pwer_score, per_score = wer( out_path, text, language=language, device=tts.device, dtype=tts.dtype, model_name=args.transcription_model, phonemize=True ) + sim_o_score = sim_o( out_path, prompt_path, device=tts.device, dtype=tts.dtype, model_name=args.speaker_similarity_model ) - metrics = {"wer": wer_score, "cer": cer_score, "sim-o": sim_o_score} + metrics = {"wer": wer_score, "cer": cer_score, "sim-o": sim_o_score, "per": per_score, "pwer": pwer_score, "wer_un": wer_un_score, "cer_un": cer_un_score } json_write( metrics, metrics_path ) else: metrics = json_read( metrics_path ) - wer_score, cer_score, sim_o_score = metrics["wer"], metrics["cer"], metrics["sim-o"] + wer_score, cer_score, per_score, sim_o_score = metrics["wer"], metrics["cer"], metrics["per"], metrics["sim-o"] if dataset_name not in metrics_map: metrics_map[dataset_name] = {} @@ -444,11 +449,11 @@ def main(): # collate entries into HTML tables = [] for dataset_name, samples in outputs: - table = "\t\t

${DATASET_NAME}

\n\t\t

Average WER: ${WER}
Average CER: ${CER}
Average SIM-O: ${SIM-O}

\n\t\t\n\t\t\t\n\t\t\t\t\n\t\t\t\t\t\n\t\t\t\t\t\n\t\t\t\t\t\n\t\t\t\t\t\n\t\t\t\t\t\n\t\t\t\t\t\n\t\t\t\t\t\n\t\t\t\t\t\n\t\t\t\t\t\n\t\t\t\t\n\t\t\t\n\t\t\t${SAMPLES}\n\t\t
TextWER↓CER↓SIM-O↑PromptOur VALL-EGround Truth
" + table = "\t\t

${DATASET_NAME}

\n\t\t

Average WER: ${WER}
Average CER: ${CER}
Average PER: ${PER}
Average SIM-O: ${SIM-O}

\n\t\t\n\t\t\t\n\t\t\t\t\n\t\t\t\t\t\n\t\t\t\t\t\n\t\t\t\t\t\n\t\t\t\t\t\n\t\t\t\t\t\n\t\t\t\t\t\n\t\t\t\t\t\n\t\t\t\t\t\n\t\t\t\t\t\n\t\t\t\t\n\t\t\t\n\t\t\t${SAMPLES}\n\t\t
TextWER↓CER↓SIM-O↑PromptOur VALL-EGround Truth
" samples = [ f'\n\t\t\t\n\t\t\t\t{text}'+ "".join([ - f'\n\t\t\t\t{metrics_map[dataset_name][audios[1]][0]:.3f}{metrics_map[dataset_name][audios[1]][1]:.3f}{metrics_map[dataset_name][audios[1]][2]:.3f}' + f'\n\t\t\t\t{metrics_map[dataset_name][audios[1]][0]:.3f}{metrics_map[dataset_name][audios[1]][1]:.3f}{metrics_map[dataset_name][audios[1]][2]:.3f}{metrics_map[dataset_name][audios[1]][3]:.3f}' ] ) + "".join( [ f'\n\t\t\t\t' @@ -461,7 +466,8 @@ def main(): # write audio into template table = table.replace("${WER}", f'{mean([ metrics[0] for metrics in metrics_map[dataset_name].values() ]):.3f}' ) table = table.replace("${CER}", f'{mean([ metrics[1] for metrics in metrics_map[dataset_name].values() ]):.3f}' ) - table = table.replace("${SIM-O}", f'{mean([ metrics[2] for metrics in metrics_map[dataset_name].values() ]):.3f}' ) + table = table.replace("${PER}", f'{mean([ metrics[2] for metrics in metrics_map[dataset_name].values() ]):.3f}' ) + table = table.replace("${SIM-O}", f'{mean([ metrics[3] for metrics in metrics_map[dataset_name].values() ]):.3f}' ) table = table.replace("${DATASET_NAME}", dataset_name) table = table.replace("${SAMPLES}", "\n".join( samples ) ) diff --git a/vall_e/metrics.py b/vall_e/metrics.py index 3fef2aa..2aeb4b3 100644 --- a/vall_e/metrics.py +++ b/vall_e/metrics.py @@ -16,7 +16,7 @@ import warnings warnings.simplefilter(action='ignore', category=FutureWarning) warnings.simplefilter(action='ignore', category=UserWarning) -def wer( audio, reference, language="auto", phonemize=True, **transcription_kwargs ): +def wer( audio, reference, language="auto", phonemize=True, normalize=True, **transcription_kwargs ): if language == "auto": language = detect_language( reference ) @@ -38,7 +38,7 @@ def wer( audio, reference, language="auto", phonemize=True, **transcription_kwar if phonemize: transcription = encode( transcription, language=language ) reference = encode( reference, language=language ) - else: + elif normalize: transcription = 
normalize_text( transcription, language=language ) reference = normalize_text( reference, language=language )