diff --git a/data/demo/index.template.html b/data/demo/index.template.html
index 23eeddc..788ecb0 100644
--- a/data/demo/index.template.html
+++ b/data/demo/index.template.html
@@ -25,6 +25,7 @@
| Text |
Prompt |
+ Our VALL-E (No LoRA) |
Our VALL-E |
Ground Truth |
diff --git a/vall_e/data.py b/vall_e/data.py
index d80b179..5310c49 100755
--- a/vall_e/data.py
+++ b/vall_e/data.py
@@ -72,7 +72,7 @@ def get_random_prompts( validation=True, length=0, tokenized=False ):
# Pull from validation dataset if existing + requested
if validation and cfg.dataset.validation:
- paths = _load_paths(cfg.dataset.validation, type="validation")
+ paths = _load_paths(cfg.dataset.validation, type="validation", silent=True)
paths = list(itertools.chain.from_iterable(paths.values()))
for path in paths:
@@ -527,8 +527,8 @@ def _get_duration_map( type="training" ):
return _durations_map[type] if type in _durations_map else {}
@cfg.diskcache()
-def _load_paths(dataset, type="training"):
- return { cfg.get_spkr( cfg.data_dir / data_dir / "dummy" ): _load_paths_from_metadata( data_dir, type=type, validate=cfg.dataset.validate and type == "training" ) for data_dir in tqdm(dataset, desc=f"Parsing dataset: {type}") }
+def _load_paths(dataset, type="training", silent=False):
+ return { cfg.get_spkr( cfg.data_dir / data_dir / "dummy" ): _load_paths_from_metadata( data_dir, type=type, validate=cfg.dataset.validate and type == "training" ) for data_dir in tqdm(dataset, desc=f"Parsing dataset: {type}", disable=silent) }
def _load_paths_from_metadata(group_name, type="training", validate=False):
data_dir = group_name if cfg.dataset.use_hdf5 else cfg.data_dir / group_name
diff --git a/vall_e/demo.py b/vall_e/demo.py
index 1abfd8d..33910f2 100644
--- a/vall_e/demo.py
+++ b/vall_e/demo.py
@@ -124,7 +124,7 @@ def main():
# pull from dataset samples
if args.sample_from_dataset:
cfg.dataset.cache = False
- cfg.dataset.sample_type = "speaker"
+ cfg.dataset.sample_type = "path" if args.lora else "speaker"
cfg.dataset.tasks_list = [ 'tts' ]
samples_dirs["dataset"] = args.demo_dir / "dataset"
@@ -133,7 +133,7 @@ def main():
dataloader = create_train_dataloader()
_logger.info("Loaded dataloader.")
- length = len( dataloader.dataset )
+ length = min(len( dataloader.dataset ), cfg.evaluation.batch_size)
num = args.dataset_samples if args.dataset_samples else length
for i in trange( num, desc="Sampling dataset for samples" ):
@@ -233,6 +233,9 @@ def main():
# write audio into template
html = html.replace("${"+k.upper()+"_SAMPLES}", "\n".join( samples ) )
+ if not args.lora:
+ html = html.replace("\n\t\t\t\t\tOur VALL-E (No LoRA) | ", "")
+
# write demo page
open( args.demo_dir / "index.html", "w", encoding="utf-8" ).write( html )