cringe script to process seed-tts-eval's eval dataset into something i can easily use

This commit is contained in:
mrq 2024-12-17 22:47:12 -06:00
parent ed152f78df
commit 9090c34f10
2 changed files with 86 additions and 19 deletions

View File

@ -0,0 +1,53 @@
"""
Handles processing seed-tts-eval's dataset into something to be used for vall_e.demo
Reads from meta.lst, a text file where each utterance is formatted as:
<reference path>|<reference text>|<prompt path>|<prompt text>
"""
import os
import json
import argparse
import torch
import shutil
import torchaudio
import numpy as np
from tqdm.auto import tqdm
from pathlib import Path
def process(
input_dir=Path("./seedtts_testset/en/"),
list_name="./meta.lst",
wav_dir="./wavs/",
output_dir=Path("./dataset/seed-tts-eval-en/"),
):
language = "auto"
if "en" in str(input_dir):
language = "en"
elif "zh" in str(input_dir):
language = "zh"
output_dir.mkdir(parents=True, exist_ok=True)
# read manifest
lines = open(input_dir / list_name).read()
lines = lines.split("\n")
# split it even further
for line in lines:
if not line:
continue
speaker, text, prompt_wav, prompt_transcription = line.split("|")
(output_dir / speaker).mkdir(parents=True, exist_ok=True)
open( output_dir / speaker / "prompt.txt", "w", encoding="utf-8" ).write( text )
open( output_dir / speaker / "language.txt", "w", encoding="utf-8" ).write( language )
shutil.copy((input_dir / wav_dir / speaker).with_suffix(".wav"), output_dir / speaker / "reference.wav" )
shutil.copy(input_dir / prompt_wav, output_dir / speaker / "prompt.wav" )
if __name__ == "__main__":
process()

View File

@ -80,7 +80,7 @@ def main():
parser.add_argument("--demo-dir", type=Path, default=None)
parser.add_argument("--skip-existing", action="store_true")
parser.add_argument("--dataset-dir-name", type=str, default="dataset")
parser.add_argument("--dataset-dir-name", type=str, default="")
parser.add_argument("--dataset-dir-name-prefix", type=str, default=None)
parser.add_argument("--sample-from-dataset", action="store_true")
parser.add_argument("--skip-loading-dataloader", action="store_true")
@ -152,13 +152,20 @@ def main():
args.demo_dir = Path("./data/demo/")
if not args.preamble:
args.preamble = "<br>".join([
'Below are some samples from my VALL-E implementation: <a href="https://git.ecker.tech/mrq/vall-e/">https://git.ecker.tech/mrq/vall-e/</a>.',
'Unlike the original VALL-E demo page, I\'m placing emphasis on the input prompt, as the model adheres to it stronger than others.',
f'Objective metrics are computed by transcribing ({args.transcription_model}) then comparing the word error rate on transcriptions (WER/CER), and computing the cosine similarities on embeddings through a speaker feature extraction model ({args.speaker_similarity_model}) (SIM-O)',
'<b>Total WER:</b> ${WER}<br>'
'<b>Total CER:</b> ${CER}<br>'
'<b>Total SIM-O:</b> ${SIM-O}<br>'
args.preamble = "\n".join([
"Past model demo pages: <a href=\"/models/ar+nar-len-llama-8 (ar+nar).html\">[ar+nar-len-llama-8 (ar+nar)]</a> <a href=\"/models/ar+nar-len-llama-8 (nar-len).html\">[ar+nar-len-llama-8 (nar-len)]</a> <a href=\"/models/ar+nar-llama-8 (ar+nar).html\">[ar+nar-llama-8 (ar+nar)]</a> | Old demo pages: <a href=\"/loras/index.html\">[1]</a> <a href=\"/loras.html\">[2]</a> <a href=\"/old/2024.10.25.html\">[3]</a> <a href=\"/old/2024.12.15.html\">[4]</a> <a href=\"/old/2024.12.16.html\">[5]</a>",
"<br>",
"<br>",
"Below are some samples from my VALL-E implementation: <a href=\"https://git.ecker.tech/mrq/vall-e/\">https://git.ecker.tech/mrq/vall-e/</a>.",
"<br>",
"Objective metrics are computed by:",
"<ul>",
" <li>WER/CER: transcribing (openai/whisper-base) then comparing the un-normalized word error rate on the phonemized transcriptions.</li>",
" <li>SIM-O: retrieving the speaker embeddings of the output audio and the input prompt, from a finetune of WavLM for speaker verification (microsoft/wavlm-large), and computing the cosine similarity between the embeddings.</li>",
"</ul>",
"Tables marked as \"Validation\" are speakers/samples not seen to the model.",
"<br>",
f"These samples were generated using <pre>--ar-temperature={args.ar_temperature} --nar-temperature={args.nar_temperature} --cfg-strength={args.cfg_strength} --max-steps={args.max_steps} --top-k={args.top_k} --dtype={args.dtype}</pre>",
])
# comparison kwargs
@ -281,18 +288,21 @@ def main():
samples_dirs = {}
# only add the existing librispeech validation dataset if i'm doing validation so I can stop commenting this out
if "validation" in args.dataset_dir_name:
sample_dir["librispeech"] = args.demo_dir / "librispeech",
if not args.dataset_dir_name:
samples_dirs["librispeech"] = args.demo_dir / "librispeech"
else:
if "validation" in args.dataset_dir_name:
samples_dirs["librispeech"] = args.demo_dir / "librispeech"
# automatically pull from anything under the dataset dir
if args.dataset_dir_name.endswith("/*"):
args.dataset_dir_name = args.dataset_dir_name[:-2]
datasets = [ dir for dir in (args.demo_dir / args.dataset_dir_name).iterdir() if dir.is_dir() ]
for path in datasets:
samples_dirs[path.name] = path
# user provided dataset
elif (args.demo_dir / args.dataset_dir_name).exists():
samples_dirs["dataset"] = args.demo_dir / args.dataset_dir_name
# automatically pull from anything under the dataset dir
if args.dataset_dir_name.endswith("/*"):
args.dataset_dir_name = args.dataset_dir_name[:-2]
datasets = [ dir for dir in (args.demo_dir / args.dataset_dir_name).iterdir() if dir.is_dir() ]
for path in datasets:
samples_dirs[path.name] = path
# user provided dataset
elif (args.demo_dir / args.dataset_dir_name).exists():
samples_dirs["dataset"] = args.demo_dir / args.dataset_dir_name
# pull from dataset samples
if args.sample_from_dataset:
@ -354,6 +364,10 @@ def main():
# generate demo output
for dir in tqdm(speakers, desc=f"Preparing demo for {dataset_name}"):
# bail if too many samples
if args.dataset_samples and len(samples) >= args.dataset_samples:
break
text = open(dir / "prompt.txt", encoding="utf-8").read()
language = open(dir / "language.txt").read() if (dir / "language.txt").exists() else "en"
prompt = dir / "prompt.wav"