From 9090c34f10a1bdf5284f7d390e9d593a4354ff7d Mon Sep 17 00:00:00 2001
From: mrq
Date: Tue, 17 Dec 2024 22:47:12 -0600
Subject: [PATCH] cringe script to process seed-tts-eval's eval dataset into
 something i can easily use

---
 scripts/process_seed-tts.py | 53 +++++++++++++++++++++++++++++++++++++
 vall_e/demo.py              | 52 +++++++++++++++++++++++-------------
 2 files changed, 86 insertions(+), 19 deletions(-)
 create mode 100644 scripts/process_seed-tts.py

diff --git a/scripts/process_seed-tts.py b/scripts/process_seed-tts.py
new file mode 100644
index 0000000..2671ddc
--- /dev/null
+++ b/scripts/process_seed-tts.py
@@ -0,0 +1,53 @@
+"""
+Handles processing seed-tts-eval's dataset into something to be used for vall_e.demo
+
+Reads from meta.lst, a text file where each utterance is formatted as:
+
+<speaker>|<text>|<prompt_wav>|<prompt_transcription>
+"""
+
+import os
+import json
+import argparse
+import torch
+import shutil
+import torchaudio
+import numpy as np
+
+from tqdm.auto import tqdm
+from pathlib import Path
+
+def process(
+	input_dir=Path("./seedtts_testset/en/"),
+	list_name="./meta.lst",
+	wav_dir="./wavs/",
+	output_dir=Path("./dataset/seed-tts-eval-en/"),
+):
+	language = "auto"
+
+	if "en" in str(input_dir):
+		language = "en"
+	elif "zh" in str(input_dir):
+		language = "zh"
+
+	output_dir.mkdir(parents=True, exist_ok=True)
+
+	# read manifest
+	lines = open(input_dir / list_name).read()
+	lines = lines.split("\n")
+	# split it even further
+	for line in lines:
+		if not line:
+			continue
+		speaker, text, prompt_wav, prompt_transcription = line.split("|")
+
+		(output_dir / speaker).mkdir(parents=True, exist_ok=True)
+
+		open( output_dir / speaker / "prompt.txt", "w", encoding="utf-8" ).write( text )
+		open( output_dir / speaker / "language.txt", "w", encoding="utf-8" ).write( language )
+
+		shutil.copy((input_dir / wav_dir / speaker).with_suffix(".wav"), output_dir / speaker / "reference.wav" )
+		shutil.copy(input_dir / prompt_wav, output_dir / speaker / "prompt.wav" )
+
+if __name__ == "__main__":
+	process()
\ No newline at end of file
diff --git a/vall_e/demo.py b/vall_e/demo.py
index ca72a56..62fc923 100644
--- a/vall_e/demo.py
+++ b/vall_e/demo.py
@@ -80,7 +80,7 @@ def main():
 	parser.add_argument("--demo-dir", type=Path, default=None)
 	parser.add_argument("--skip-existing", action="store_true")
-	parser.add_argument("--dataset-dir-name", type=str, default="dataset")
+	parser.add_argument("--dataset-dir-name", type=str, default="")
 	parser.add_argument("--dataset-dir-name-prefix", type=str, default=None)
 	parser.add_argument("--sample-from-dataset", action="store_true")
 	parser.add_argument("--skip-loading-dataloader", action="store_true")
@@ -152,13 +152,20 @@ def main():
 		args.demo_dir = Path("./data/demo/")
 
 	if not args.preamble:
-		args.preamble = "<br>".join([
-			'Below are some samples from my VALL-E implementation: https://git.ecker.tech/mrq/vall-e/.',
-			'Unlike the original VALL-E demo page, I\'m placing emphasis on the input prompt, as the model adheres to it stronger than others.',
-			f'Objective metrics are computed by transcribing ({args.transcription_model}) then comparing the word error rate on transcriptions (WER/CER), and computing the cosine similarities on embeddings through a speaker feature extraction model ({args.speaker_similarity_model}) (SIM-O)',
-			'Total WER: ${WER}<br>'
-			'Total CER: ${CER}<br>'
-			'Total SIM-O: ${SIM-O}<br>'
+		args.preamble = "\n".join([
+			"Past model demo pages: [ar+nar-len-llama-8 (ar+nar)] [ar+nar-len-llama-8 (nar-len)] [ar+nar-llama-8 (ar+nar)] | Old demo pages: [1] [2] [3] [4] [5]",
+			"<br>",
+			"<br>",
+			"Below are some samples from my VALL-E implementation: https://git.ecker.tech/mrq/vall-e/.",
+			"<br>",
+			"Objective metrics are computed by:",
+			"<ul>",
+			"<li>WER/CER: transcribing (openai/whisper-base) then comparing the un-normalized word error rate on the phonemized transcriptions.</li>",
+			"<li>SIM-O: retrieving the speaker embeddings of the output audio and the input prompt, from a finetune of WavLM for speaker verification (microsoft/wavlm-large), and computing the cosine similarity between the embeddings.</li>",
+			"</ul>",
+			"Tables marked as \"Validation\" are speakers/samples not seen to the model.",
+			"<br>",
+			f"These samples were generated using --ar-temperature={args.ar_temperature} --nar-temperature={args.nar_temperature} --cfg-strength={args.cfg_strength} --max-steps={args.max_steps} --top-k={args.top_k} --dtype={args.dtype}",
 		])
 
 	# comparison kwargs
@@ -281,18 +288,21 @@ def main():
 	samples_dirs = {}
 
 	# only add the existing librispeech validation dataset if i'm doing validation so I can stop commenting this out
-	if "validation" in args.dataset_dir_name:
-		sample_dir["librispeech"] = args.demo_dir / "librispeech",
+	if not args.dataset_dir_name:
+		samples_dirs["librispeech"] = args.demo_dir / "librispeech"
+	else:
+		if "validation" in args.dataset_dir_name:
+			samples_dirs["librispeech"] = args.demo_dir / "librispeech"
 
-	# automatically pull from anything under the dataset dir
-	if args.dataset_dir_name.endswith("/*"):
-		args.dataset_dir_name = args.dataset_dir_name[:-2]
-		datasets = [ dir for dir in (args.demo_dir / args.dataset_dir_name).iterdir() if dir.is_dir() ]
-		for path in datasets:
-			samples_dirs[path.name] = path
-	# user provided dataset
-	elif (args.demo_dir / args.dataset_dir_name).exists():
-		samples_dirs["dataset"] = args.demo_dir / args.dataset_dir_name
+		# automatically pull from anything under the dataset dir
+		if args.dataset_dir_name.endswith("/*"):
+			args.dataset_dir_name = args.dataset_dir_name[:-2]
+			datasets = [ dir for dir in (args.demo_dir / args.dataset_dir_name).iterdir() if dir.is_dir() ]
+			for path in datasets:
+				samples_dirs[path.name] = path
+		# user provided dataset
+		elif (args.demo_dir / args.dataset_dir_name).exists():
+			samples_dirs["dataset"] = args.demo_dir / args.dataset_dir_name
 
 	# pull from dataset samples
 	if args.sample_from_dataset:
@@ -354,6 +364,10 @@ def main():
 
 		# generate demo output
 		for dir in tqdm(speakers, desc=f"Preparing demo for {dataset_name}"):
+			# bail if too many samples
+			if args.dataset_samples and len(samples) >= args.dataset_samples:
+				break
+
 			text = open(dir / "prompt.txt", encoding="utf-8").read()
 			language = open(dir / "language.txt").read() if (dir / "language.txt").exists() else "en"
 			prompt = dir / "prompt.wav"