+Wayne went to Wales to watch walruses.
\ No newline at end of file diff --git a/vall_e/data.py b/vall_e/data.py index ea066d8..9fcfe98 100755 --- a/vall_e/data.py +++ b/vall_e/data.py @@ -192,7 +192,7 @@ def normalize_text(text, language="auto", full=True): return text @cache -def get_random_prompts( validation=False, min_length=0, tokenized=False ): +def get_random_prompts( validation=False, min_length=0, tokenized=False, source_path=Path("./data/harvard_sentences.txt") ): duration_range = [ 5.5, 12.0 ] # to-do: pull from cfg.dataset.duration_range sentences = [ "The birch canoe slid on the smooth planks.", @@ -228,9 +228,8 @@ def get_random_prompts( validation=False, min_length=0, tokenized=False ): "Perfect. Please move quickly to the chamber lock, as the effect of prolonged exposure to the button are not part of this test.", ] - harvard_sentences_path = Path("./data/harvard_sentences.txt") - if harvard_sentences_path.exists(): - sentences = open( harvard_sentences_path, "r", encoding="utf-8" ).read().split("\n") + if source_path.exists(): + sentences = open( source_path, "r", encoding="utf-8" ).read().split("\n") # Pull from validation dataset if existing + requested if validation and cfg.dataset.validation: diff --git a/vall_e/demo.py b/vall_e/demo.py index 092eb0d..16ab412 100644 --- a/vall_e/demo.py +++ b/vall_e/demo.py @@ -426,15 +426,20 @@ def main(): calculate = not metrics_path.exists() or (metrics_path.stat().st_mtime < out_path.stat().st_mtime) if calculate: - wer_score, cer_score = wer( out_path, text, language=language, device=tts.device, dtype=tts.dtype, model_name=args.transcription_model, phonemize=True ) - #wer_score, cer_score = wer( out_path, reference_path, language=language, device=tts.device, dtype=tts.dtype, model_name=args.transcription_model, phonemize=False ) + # computes based on word transcriptions outright + wer_score, cer_score = wer( out_path, text, language=language, device=tts.device, dtype=tts.dtype, model_name=args.transcription_model, phonemize=False ) + # compute 
on words as well, but does not normalize + wer_un_score, cer_un_score = wer( out_path, text, language=language, device=tts.device, dtype=tts.dtype, model_name=args.transcription_model, phonemize=False, normalize=False ) + # computes on phonemes instead + pwer_score, per_score = wer( out_path, text, language=language, device=tts.device, dtype=tts.dtype, model_name=args.transcription_model, phonemize=True ) + sim_o_score = sim_o( out_path, prompt_path, device=tts.device, dtype=tts.dtype, model_name=args.speaker_similarity_model ) - metrics = {"wer": wer_score, "cer": cer_score, "sim-o": sim_o_score} + metrics = {"wer": wer_score, "cer": cer_score, "sim-o": sim_o_score, "per": per_score, "pwer": pwer_score, "wer_un": wer_un_score, "cer_un": cer_un_score } json_write( metrics, metrics_path ) else: metrics = json_read( metrics_path ) - wer_score, cer_score, sim_o_score = metrics["wer"], metrics["cer"], metrics["sim-o"] + wer_score, cer_score, per_score, sim_o_score = metrics["wer"], metrics["cer"], metrics["per"], metrics["sim-o"] if dataset_name not in metrics_map: metrics_map[dataset_name] = {} @@ -444,11 +449,11 @@ def main(): # collate entries into HTML tables = [] for dataset_name, samples in outputs: - table = "\t\t
Average WER: ${WER}
Average CER: ${CER}
Average SIM-O: ${SIM-O}
Text | \n\t\t\t\t\tWER↓ | \n\t\t\t\t\tCER↓ | \n\t\t\t\t\tSIM-O↑ | \n\t\t\t\t\tPrompt | \n\t\t\t\t\tOur VALL-E | \n\t\t\t\t\t\n\t\t\t\t\t\n\t\t\t\t\tGround Truth | \n\t\t\t\t
---|
Average WER: ${WER}
Average CER: ${CER}
Average PER: ${PER}
Average SIM-O: ${SIM-O}
Text | \n\t\t\t\t\tWER↓ | \n\t\t\t\t\tCER↓ | \n\t\t\t\t\tSIM-O↑ | \n\t\t\t\t\tPrompt | \n\t\t\t\t\tOur VALL-E | \n\t\t\t\t\t\n\t\t\t\t\t\n\t\t\t\t\tGround Truth | \n\t\t\t\t
---|