demo page tweaks

This commit is contained in:
mrq 2024-10-10 13:52:37 -05:00
parent 2ea978f318
commit 96d05be73c
3 changed files with 9 additions and 5 deletions

View File

@ -25,6 +25,7 @@
<tr>
<th>Text</th>
<th>Prompt</th>
<th>Our VALL-E (No LoRA)</th>
<th>Our VALL-E</th>
<th>Ground Truth</th>
</tr>

View File

@ -72,7 +72,7 @@ def get_random_prompts( validation=True, length=0, tokenized=False ):
# Pull from validation dataset if existing + requested
if validation and cfg.dataset.validation:
paths = _load_paths(cfg.dataset.validation, type="validation")
paths = _load_paths(cfg.dataset.validation, type="validation", silent=True)
paths = list(itertools.chain.from_iterable(paths.values()))
for path in paths:
@ -527,8 +527,8 @@ def _get_duration_map( type="training" ):
return _durations_map[type] if type in _durations_map else {}
@cfg.diskcache()
def _load_paths(dataset, type="training"):
return { cfg.get_spkr( cfg.data_dir / data_dir / "dummy" ): _load_paths_from_metadata( data_dir, type=type, validate=cfg.dataset.validate and type == "training" ) for data_dir in tqdm(dataset, desc=f"Parsing dataset: {type}") }
def _load_paths(dataset, type="training", silent=False):
return { cfg.get_spkr( cfg.data_dir / data_dir / "dummy" ): _load_paths_from_metadata( data_dir, type=type, validate=cfg.dataset.validate and type == "training" ) for data_dir in tqdm(dataset, desc=f"Parsing dataset: {type}", disable=silent) }
def _load_paths_from_metadata(group_name, type="training", validate=False):
data_dir = group_name if cfg.dataset.use_hdf5 else cfg.data_dir / group_name

View File

@ -124,7 +124,7 @@ def main():
# pull from dataset samples
if args.sample_from_dataset:
cfg.dataset.cache = False
cfg.dataset.sample_type = "speaker"
cfg.dataset.sample_type = "path" if args.lora else "speaker"
cfg.dataset.tasks_list = [ 'tts' ]
samples_dirs["dataset"] = args.demo_dir / "dataset"
@ -133,7 +133,7 @@ def main():
dataloader = create_train_dataloader()
_logger.info("Loaded dataloader.")
length = len( dataloader.dataset )
length = min(len( dataloader.dataset ), cfg.evaluation.batch_size)
num = args.dataset_samples if args.dataset_samples else length
for i in trange( num, desc="Sampling dataset for samples" ):
@ -233,6 +233,9 @@ def main():
# write audio into template
html = html.replace("${"+k.upper()+"_SAMPLES}", "\n".join( samples ) )
if not args.lora:
html = html.replace("\n\t\t\t\t\t<th>Our VALL-E (No LoRA)</th>", "")
# write demo page
open( args.demo_dir / "index.html", "w", encoding="utf-8" ).write( html )