From 96d05be73c7c0f27bc8e15f420501a207a7ba972 Mon Sep 17 00:00:00 2001 From: mrq Date: Thu, 10 Oct 2024 13:52:37 -0500 Subject: [PATCH] demo page tweaks --- data/demo/index.template.html | 1 + vall_e/data.py | 6 +++--- vall_e/demo.py | 7 +++++-- 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/data/demo/index.template.html b/data/demo/index.template.html index 23eeddc..788ecb0 100644 --- a/data/demo/index.template.html +++ b/data/demo/index.template.html @@ -25,6 +25,7 @@ Text Prompt + Our VALL-E (No LoRA) Our VALL-E Ground Truth diff --git a/vall_e/data.py b/vall_e/data.py index d80b179..5310c49 100755 --- a/vall_e/data.py +++ b/vall_e/data.py @@ -72,7 +72,7 @@ def get_random_prompts( validation=True, length=0, tokenized=False ): # Pull from validation dataset if existing + requested if validation and cfg.dataset.validation: - paths = _load_paths(cfg.dataset.validation, type="validation") + paths = _load_paths(cfg.dataset.validation, type="validation", silent=True) paths = list(itertools.chain.from_iterable(paths.values())) for path in paths: @@ -527,8 +527,8 @@ def _get_duration_map( type="training" ): return _durations_map[type] if type in _durations_map else {} @cfg.diskcache() -def _load_paths(dataset, type="training"): - return { cfg.get_spkr( cfg.data_dir / data_dir / "dummy" ): _load_paths_from_metadata( data_dir, type=type, validate=cfg.dataset.validate and type == "training" ) for data_dir in tqdm(dataset, desc=f"Parsing dataset: {type}") } +def _load_paths(dataset, type="training", silent=False): + return { cfg.get_spkr( cfg.data_dir / data_dir / "dummy" ): _load_paths_from_metadata( data_dir, type=type, validate=cfg.dataset.validate and type == "training" ) for data_dir in tqdm(dataset, desc=f"Parsing dataset: {type}", disable=silent) } def _load_paths_from_metadata(group_name, type="training", validate=False): data_dir = group_name if cfg.dataset.use_hdf5 else cfg.data_dir / group_name diff --git a/vall_e/demo.py b/vall_e/demo.py index 1abfd8d..33910f2 100644 --- a/vall_e/demo.py +++ b/vall_e/demo.py @@ -124,7 +124,7 @@ def main(): # pull from dataset samples if args.sample_from_dataset: cfg.dataset.cache = False - cfg.dataset.sample_type = "speaker" + cfg.dataset.sample_type = "path" if args.lora else "speaker" cfg.dataset.tasks_list = [ 'tts' ] samples_dirs["dataset"] = args.demo_dir / "dataset" @@ -133,7 +133,7 @@ def main(): dataloader = create_train_dataloader() _logger.info("Loaded dataloader.") - length = len( dataloader.dataset ) + length = min(len( dataloader.dataset ), cfg.evaluation.batch_size) num = args.dataset_samples if args.dataset_samples else length for i in trange( num, desc="Sampling dataset for samples" ): @@ -233,6 +233,9 @@ def main(): # write audio into template html = html.replace("${"+k.upper()+"_SAMPLES}", "\n".join( samples ) ) + if not args.lora: + html = html.replace("\n\t\t\t\t\tOur VALL-E (No LoRA)", "") + # write demo page open( args.demo_dir / "index.html", "w", encoding="utf-8" ).write( html )