From ed152f78dff1b25c48049f49cb304543c2f80a0b Mon Sep 17 00:00:00 2001
From: mrq
Date: Tue, 17 Dec 2024 19:33:04 -0600
Subject: [PATCH] tweaks to prompt duration to allow me to divorce how I use
 it for training from how I'm using it for the demo page, and demo page
 tweaks to make my life easier

---
 data/demo/index.template.html | 34 +------------------
 vall_e/data.py                | 12 +++++--
 vall_e/demo.py                | 61 ++++++++++++++++++++++-------------
 vall_e/inference.py           | 16 +++++----
 4 files changed, 57 insertions(+), 66 deletions(-)

diff --git a/data/demo/index.template.html b/data/demo/index.template.html
index 592788c..cbc1362 100644
--- a/data/demo/index.template.html
+++ b/data/demo/index.template.html
@@ -5,39 +5,7 @@
 
 		<h1>VALL-E Demo</h1>
 
 		${PREAMBLE}
 
-		<h3>LibriSpeech</h3>
-		<table>
-			<thead>
-				<tr>
-					<th>Text</th>
-					<th>WER↓</th>
-					<th>CER↓</th>
-					<th>SIM-O↑</th>
-					<th>Prompt</th>
-					<th>Our VALL-E</th>
-					<th>Original VALL-E</th>
-					<th>Ground Truth</th>
-				</tr>
-			</thead>
-			${LIBRISPEECH_SAMPLES}
-		</table>
-		<h3>Sampled Dataset</h3>
-		<table>
-			<thead>
-				<tr>
-					<th>Text</th>
-					<th>WER↓</th>
-					<th>CER↓</th>
-					<th>SIM-O↑</th>
-					<th>Prompt</th>
-					<th>Our VALL-E</th>
-					<th>Ground Truth</th>
-				</tr>
-			</thead>
-			${DATASET_SAMPLES}
-		</table>
+		${TABLES}
 
 		Settings used:
 
 		${SETTINGS}
 
\ No newline at end of file
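The template change above replaces the two hardcoded tables with a single ${TABLES} placeholder that demo.py fills with one pre-rendered table per dataset. A minimal sketch of that substitution, assuming an illustrative helper name (render_demo_page is not from the repo):

	from pathlib import Path

	def render_demo_page(template_path: Path, tables: list[str]) -> str:
		# splice one pre-rendered <table> block per dataset into the template
		html = template_path.read_text(encoding="utf-8")
		return html.replace("${TABLES}", "\n".join(tables))
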
diff --git a/vall_e/data.py b/vall_e/data.py
index e88695c..e2243fb 100755
--- a/vall_e/data.py
+++ b/vall_e/data.py
@@ -1058,7 +1058,7 @@ class Dataset(_Dataset):
 	def sample_prompts(self, spkr_name, reference, should_trim=True):
 		# return no prompt if explicitly requested for who knows why
 		# or if there's no other speakers to sample from (Emilia has a lot of singleton speakers, but I still want to make use of them)
-		if not cfg.dataset.prompt_duration_range or cfg.dataset.prompt_duration_range[-1] == 0 or len(self.paths_by_spkr_name[spkr_name]) <= 1:
+		if len(self.paths_by_spkr_name[spkr_name]) <= 1:
 			return None
 
 		prom_list = []
@@ -1075,9 +1075,15 @@ class Dataset(_Dataset):
 		)
 		"""
 
+		if not cfg.dataset.prompt_duration_range or cfg.dataset.prompt_duration_range[1] <= 0:
+			should_trim = False
+
 		prom_length = 0
-		duration_lo, duration_hi = cfg.dataset.prompt_duration_range
-		trim_length = int(random.uniform(duration_lo, duration_hi) * cfg.dataset.frames_per_second) if trim else 0
+		if should_trim:
+			duration_lo, duration_hi = cfg.dataset.prompt_duration_range
+			trim_length = int(random.uniform(duration_lo, duration_hi) * cfg.dataset.frames_per_second)
+		else:
+			trim_length = 0
 
 		for _ in range(cfg.dataset.prompt_max_samples):
 			if reference is not None:
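With the data.py change, an unset or non-positive prompt_duration_range now just disables trimming rather than returning no prompt at all; only singleton speakers still return None. A self-contained sketch of the resulting trim computation, with assumed example values standing in for cfg.dataset (the 75 frames-per-second rate is illustrative):

	import random

	prompt_duration_range = (3.0, 6.0)  # seconds; stand-in for cfg.dataset.prompt_duration_range
	frames_per_second = 75              # stand-in for cfg.dataset.frames_per_second, illustrative

	should_trim = True
	if not prompt_duration_range or prompt_duration_range[1] <= 0:
		should_trim = False  # keep the prompt, just don't trim it

	if should_trim:
		duration_lo, duration_hi = prompt_duration_range
		# e.g. uniform(3.0, 6.0) * 75 frames/s -> a target length between 225 and 450 frames
		trim_length = int(random.uniform(duration_lo, duration_hi) * frames_per_second)
	else:
		trim_length = 0  # a length of 0 means the sampled prompt is used untrimmed
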
--ref_audio "{reference}" --gen_text "{text}" --output_dir "{out_path.parent}"' """ - if not args.random_prompts or k == "librispeech": + if not args.random_prompts or dataset_name == "librispeech": audio_samples += [ reference ] samples.append(( @@ -383,16 +390,16 @@ def main(): if should_generate: comparison_inputs.append((text, prompt, language, out_path_comparison)) - metrics_inputs.append((text, language, out_path_comparison, prompt, reference, metrics_path)) + metrics_inputs.append((dataset_name, text, language, out_path_comparison, prompt, reference, metrics_path)) should_generate = (args.skip_existing and not out_path.exists()) or not (args.skip_existing) if should_generate: inputs.append((text, prompt, language, out_path)) - metrics_inputs.append((text, language, out_path, prompt, reference, metrics_path)) + metrics_inputs.append((dataset_name, text, language, out_path, prompt, reference, metrics_path)) - outputs.append((k, samples)) + outputs.append((dataset_name, samples)) if inputs: process_batch( tts, inputs, sampling_kwargs | (comparison_kwargs["disabled"] if args.comparison else {}) ) @@ -401,28 +408,32 @@ def main(): process_batch( tts, comparison_inputs, sampling_kwargs | (comparison_kwargs["enabled"] if args.comparison else {}) ) metrics_map = {} - for text, language, out_path, prompt_path, reference_path, metrics_path in tqdm(metrics_inputs, desc="Calculating metrics"): + for dataset_name, text, language, out_path, prompt_path, reference_path, metrics_path in tqdm(metrics_inputs, desc="Calculating metrics"): calculate = not metrics_path.exists() or (metrics_path.stat().st_mtime < out_path.stat().st_mtime) if calculate: wer_score, cer_score = wer( out_path, text, language=language, device=tts.device, dtype=tts.dtype, model_name=args.transcription_model ) sim_o_score = sim_o( out_path, prompt_path, device=tts.device, dtype=tts.dtype, model_name=args.speaker_similarity_model ) - #sim_o_r_score = sim_o( out_path, reference_path, device=tts.device, dtype=tts.dtype, model_name=args.speaker_similarity_model ) - metrics = {"wer": wer_score, "cer": cer_score, "sim-o": sim_o_score} # , "sim-o-r": sim_o_r_score} + metrics = {"wer": wer_score, "cer": cer_score, "sim-o": sim_o_score} json_write( metrics, metrics_path ) else: metrics = json_read( metrics_path ) wer_score, cer_score, sim_o_score = metrics["wer"], metrics["cer"], metrics["sim-o"] - metrics_map[out_path] = (wer_score, cer_score, sim_o_score) + if dataset_name not in metrics_map: + metrics_map[dataset_name] = {} + + metrics_map[dataset_name][out_path] = (wer_score, cer_score, sim_o_score) # collate entries into HTML - for k, samples in outputs: + tables = [] + for dataset_name, samples in outputs: + table = "\t\t

${DATASET_NAME}

\n\t\t

Average WER: ${WER}
Average CER: ${CER}
Average SIM-O: ${SIM-O}

\n\t\t\n\t\t\t\n\t\t\t\t\n\t\t\t\t\t\n\t\t\t\t\t\n\t\t\t\t\t\n\t\t\t\t\t\n\t\t\t\t\t\n\t\t\t\t\t\n\t\t\t\t\t\n\t\t\t\t\t\n\t\t\t\t\t\n\t\t\t\t\n\t\t\t\n\t\t\t${SAMPLES}\n\t\t
TextWER↓CER↓SIM-O↑PromptOur VALL-EGround Truth
" samples = [ f'\n\t\t\t\n\t\t\t\t{text}'+ "".join([ - f'\n\t\t\t\t{metrics_map[audios[1]][0]:.3f}{metrics_map[audios[1]][1]:.3f}{metrics_map[audios[1]][2]:.3f}' + f'\n\t\t\t\t{metrics_map[dataset_name][audios[1]][0]:.3f}{metrics_map[dataset_name][audios[1]][1]:.3f}{metrics_map[dataset_name][audios[1]][2]:.3f}' ] ) + "".join( [ f'\n\t\t\t\t' @@ -433,11 +444,15 @@ def main(): ] # write audio into template - html = html.replace("${"+k.upper()+"_SAMPLES}", "\n".join( samples ) ) + table = table.replace("${WER}", f'{mean([ metrics[0] for metrics in metrics_map[dataset_name].values() ]):.3f}' ) + table = table.replace("${CER}", f'{mean([ metrics[1] for metrics in metrics_map[dataset_name].values() ]):.3f}' ) + table = table.replace("${SIM-O}", f'{mean([ metrics[2] for metrics in metrics_map[dataset_name].values() ]):.3f}' ) + + table = table.replace("${DATASET_NAME}", dataset_name) + table = table.replace("${SAMPLES}", "\n".join( samples ) ) + tables.append( table ) - html = html.replace("${WER}", f'{mean([ metrics[0] for metrics in metrics_map.values() ]):.3f}' ) - html = html.replace("${CER}", f'{mean([ metrics[1] for metrics in metrics_map.values() ]):.3f}' ) - html = html.replace("${SIM-O}", f'{mean([ metrics[2] for metrics in metrics_map.values() ]):.3f}' ) + html = html.replace("${TABLES}", "\n".join( tables )) if args.comparison: disabled, enabled = comparison_kwargs["titles"] diff --git a/vall_e/inference.py b/vall_e/inference.py index fb6c40b..fa1b29d 100755 --- a/vall_e/inference.py +++ b/vall_e/inference.py @@ -11,6 +11,8 @@ from torch import Tensor from einops import rearrange from pathlib import Path +from tqdm import tqdm, trange + from .emb import g2p, qnt from .emb.qnt import trim, trim_random, unload_model, repeat_extend_audio from .emb.transcribe import transcribe @@ -213,7 +215,7 @@ class TTS(): input_prompt_length = sampling_kwargs.pop("input_prompt_length", 0) modality = sampling_kwargs.pop("modality", "auto") seed = sampling_kwargs.pop("seed", None) - tqdm = sampling_kwargs.pop("tqdm", True) + use_tqdm = sampling_kwargs.pop("tqdm", True) use_lora = sampling_kwargs.pop("use_lora", None) dtype = sampling_kwargs.pop("dtype", self.dtype) amp = sampling_kwargs.pop("amp", self.amp) @@ -256,7 +258,7 @@ class TTS(): inputs = [] # tensorfy inputs - for i in range( samples ): + for i in trange( samples, desc="Preparing batches" ): # detect language if languages[i] == "auto": languages[i] = g2p.detect_language( texts[i] ) @@ -295,14 +297,14 @@ class TTS(): buffer = ([], [], [], []) wavs = [] - for texts, proms, langs, out_paths in batches: + for texts, proms, langs, out_paths in tqdm(batches, desc="Processing batch"): seed = set_seed(seed) batch_size = len(texts) input_kwargs = dict( text_list=texts, proms_list=proms, lang_list=langs, - disable_tqdm=not tqdm, + disable_tqdm=not use_tqdm, use_lora=use_lora, ) @@ -355,7 +357,7 @@ class TTS(): input_prompt_length = sampling_kwargs.pop("input_prompt_length", 0) modality = sampling_kwargs.pop("modality", "auto") seed = sampling_kwargs.pop("seed", None) - tqdm = sampling_kwargs.pop("tqdm", True) + use_tqdm = sampling_kwargs.pop("tqdm", True) use_lora = sampling_kwargs.pop("use_lora", None) dtype = sampling_kwargs.pop("dtype", self.dtype) amp = sampling_kwargs.pop("amp", self.amp) @@ -405,7 +407,7 @@ class TTS(): if model is not None: text_list = model( text_list=None, proms_list=[resp], lang_list=[lang], resps_list=[resp], task_list=["stt"], - disable_tqdm=not tqdm, + disable_tqdm=not use_tqdm, use_lora=use_lora, **sampling_kwargs, ) @@ 
diff --git a/vall_e/inference.py b/vall_e/inference.py
index fb6c40b..fa1b29d 100755
--- a/vall_e/inference.py
+++ b/vall_e/inference.py
@@ -11,6 +11,8 @@ from torch import Tensor
 from einops import rearrange
 from pathlib import Path
 
+from tqdm import tqdm, trange
+
 from .emb import g2p, qnt
 from .emb.qnt import trim, trim_random, unload_model, repeat_extend_audio
 from .emb.transcribe import transcribe
@@ -213,7 +215,7 @@ class TTS():
 		input_prompt_length = sampling_kwargs.pop("input_prompt_length", 0)
 		modality = sampling_kwargs.pop("modality", "auto")
 		seed = sampling_kwargs.pop("seed", None)
-		tqdm = sampling_kwargs.pop("tqdm", True)
+		use_tqdm = sampling_kwargs.pop("tqdm", True)
 		use_lora = sampling_kwargs.pop("use_lora", None)
 		dtype = sampling_kwargs.pop("dtype", self.dtype)
 		amp = sampling_kwargs.pop("amp", self.amp)
@@ -256,7 +258,7 @@ class TTS():
 
 		inputs = []
 		# tensorfy inputs
-		for i in range( samples ):
+		for i in trange( samples, desc="Preparing batches" ):
 			# detect language
 			if languages[i] == "auto":
 				languages[i] = g2p.detect_language( texts[i] )
@@ -295,14 +297,14 @@ class TTS():
 				buffer = ([], [], [], [])
 
 		wavs = []
-		for texts, proms, langs, out_paths in batches:
+		for texts, proms, langs, out_paths in tqdm(batches, desc="Processing batch"):
 			seed = set_seed(seed)
 			batch_size = len(texts)
 			input_kwargs = dict(
 				text_list=texts,
 				proms_list=proms,
 				lang_list=langs,
-				disable_tqdm=not tqdm,
+				disable_tqdm=not use_tqdm,
 				use_lora=use_lora,
 			)
@@ -355,7 +357,7 @@ class TTS():
 		input_prompt_length = sampling_kwargs.pop("input_prompt_length", 0)
 		modality = sampling_kwargs.pop("modality", "auto")
 		seed = sampling_kwargs.pop("seed", None)
-		tqdm = sampling_kwargs.pop("tqdm", True)
+		use_tqdm = sampling_kwargs.pop("tqdm", True)
 		use_lora = sampling_kwargs.pop("use_lora", None)
 		dtype = sampling_kwargs.pop("dtype", self.dtype)
 		amp = sampling_kwargs.pop("amp", self.amp)
@@ -405,7 +407,7 @@ class TTS():
 		if model is not None:
 			text_list = model(
 				text_list=None, proms_list=[resp], lang_list=[lang], resps_list=[resp], task_list=["stt"],
-				disable_tqdm=not tqdm,
+				disable_tqdm=not use_tqdm,
 				use_lora=use_lora,
 				**sampling_kwargs,
 			)
@@ -452,7 +454,7 @@ class TTS():
 				text_list=[phns],
 				proms_list=[prom],
 				lang_list=[lang],
-				disable_tqdm=not tqdm,
+				disable_tqdm=not use_tqdm,
 				use_lora=use_lora,
 			)
 		if model_len is not None:
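The renames from tqdm to use_tqdm in inference.py are not cosmetic: with "from tqdm import tqdm, trange" now at module level, popping the kwarg into a local named tqdm would shadow the import for the whole function body. A minimal illustration of the pitfall (function names are illustrative, not from the repo):

	from tqdm import tqdm

	def generate_shadowed(**sampling_kwargs):
		tqdm = sampling_kwargs.pop("tqdm", True)  # local 'tqdm' now shadows the import
		for batch in tqdm(range(3)):              # TypeError: 'bool' object is not callable
			pass

	def generate_fixed(**sampling_kwargs):
		use_tqdm = sampling_kwargs.pop("tqdm", True)  # distinct name keeps the import usable
		for batch in tqdm(range(3), disable=not use_tqdm):
			pass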