demo page tweaks
This commit is contained in:
parent
2ea978f318
commit
96d05be73c
|
@ -25,6 +25,7 @@
|
|||
<tr>
|
||||
<th>Text</th>
|
||||
<th>Prompt</th>
|
||||
<th>Our VALL-E (No LoRA)</th>
|
||||
<th>Our VALL-E</th>
|
||||
<th>Ground Truth</th>
|
||||
</tr>
|
||||
|
|
|
@ -72,7 +72,7 @@ def get_random_prompts( validation=True, length=0, tokenized=False ):
|
|||
|
||||
# Pull from validation dataset if existing + requested
|
||||
if validation and cfg.dataset.validation:
|
||||
paths = _load_paths(cfg.dataset.validation, type="validation")
|
||||
paths = _load_paths(cfg.dataset.validation, type="validation", silent=True)
|
||||
paths = list(itertools.chain.from_iterable(paths.values()))
|
||||
|
||||
for path in paths:
|
||||
|
@ -527,8 +527,8 @@ def _get_duration_map( type="training" ):
|
|||
return _durations_map[type] if type in _durations_map else {}
|
||||
|
||||
@cfg.diskcache()
|
||||
def _load_paths(dataset, type="training"):
|
||||
return { cfg.get_spkr( cfg.data_dir / data_dir / "dummy" ): _load_paths_from_metadata( data_dir, type=type, validate=cfg.dataset.validate and type == "training" ) for data_dir in tqdm(dataset, desc=f"Parsing dataset: {type}") }
|
||||
def _load_paths(dataset, type="training", silent=False):
|
||||
return { cfg.get_spkr( cfg.data_dir / data_dir / "dummy" ): _load_paths_from_metadata( data_dir, type=type, validate=cfg.dataset.validate and type == "training" ) for data_dir in tqdm(dataset, desc=f"Parsing dataset: {type}", disable=silent) }
|
||||
|
||||
def _load_paths_from_metadata(group_name, type="training", validate=False):
|
||||
data_dir = group_name if cfg.dataset.use_hdf5 else cfg.data_dir / group_name
|
||||
|
|
|
@ -124,7 +124,7 @@ def main():
|
|||
# pull from dataset samples
|
||||
if args.sample_from_dataset:
|
||||
cfg.dataset.cache = False
|
||||
cfg.dataset.sample_type = "speaker"
|
||||
cfg.dataset.sample_type = "path" if args.lora else "speaker"
|
||||
cfg.dataset.tasks_list = [ 'tts' ]
|
||||
|
||||
samples_dirs["dataset"] = args.demo_dir / "dataset"
|
||||
|
@ -133,7 +133,7 @@ def main():
|
|||
dataloader = create_train_dataloader()
|
||||
_logger.info("Loaded dataloader.")
|
||||
|
||||
length = len( dataloader.dataset )
|
||||
length = min(len( dataloader.dataset ), cfg.evaluation.batch_size)
|
||||
num = args.dataset_samples if args.dataset_samples else length
|
||||
|
||||
for i in trange( num, desc="Sampling dataset for samples" ):
|
||||
|
@ -233,6 +233,9 @@ def main():
|
|||
# write audio into template
|
||||
html = html.replace("${"+k.upper()+"_SAMPLES}", "\n".join( samples ) )
|
||||
|
||||
if not args.lora:
|
||||
html = html.replace("\n\t\t\t\t\t<th>Our VALL-E (No LoRA)</th>", "")
|
||||
|
||||
# write demo page
|
||||
open( args.demo_dir / "index.html", "w", encoding="utf-8" ).write( html )
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user