demo page tweaks
This commit is contained in:
parent
2ea978f318
commit
96d05be73c
|
@ -25,6 +25,7 @@
|
||||||
<tr>
|
<tr>
|
||||||
<th>Text</th>
|
<th>Text</th>
|
||||||
<th>Prompt</th>
|
<th>Prompt</th>
|
||||||
|
<th>Our VALL-E (No LoRA)</th>
|
||||||
<th>Our VALL-E</th>
|
<th>Our VALL-E</th>
|
||||||
<th>Ground Truth</th>
|
<th>Ground Truth</th>
|
||||||
</tr>
|
</tr>
|
||||||
|
|
|
@ -72,7 +72,7 @@ def get_random_prompts( validation=True, length=0, tokenized=False ):
|
||||||
|
|
||||||
# Pull from validation dataset if existing + requested
|
# Pull from validation dataset if existing + requested
|
||||||
if validation and cfg.dataset.validation:
|
if validation and cfg.dataset.validation:
|
||||||
paths = _load_paths(cfg.dataset.validation, type="validation")
|
paths = _load_paths(cfg.dataset.validation, type="validation", silent=True)
|
||||||
paths = list(itertools.chain.from_iterable(paths.values()))
|
paths = list(itertools.chain.from_iterable(paths.values()))
|
||||||
|
|
||||||
for path in paths:
|
for path in paths:
|
||||||
|
@ -527,8 +527,8 @@ def _get_duration_map( type="training" ):
|
||||||
return _durations_map[type] if type in _durations_map else {}
|
return _durations_map[type] if type in _durations_map else {}
|
||||||
|
|
||||||
@cfg.diskcache()
|
@cfg.diskcache()
|
||||||
def _load_paths(dataset, type="training"):
|
def _load_paths(dataset, type="training", silent=False):
|
||||||
return { cfg.get_spkr( cfg.data_dir / data_dir / "dummy" ): _load_paths_from_metadata( data_dir, type=type, validate=cfg.dataset.validate and type == "training" ) for data_dir in tqdm(dataset, desc=f"Parsing dataset: {type}") }
|
return { cfg.get_spkr( cfg.data_dir / data_dir / "dummy" ): _load_paths_from_metadata( data_dir, type=type, validate=cfg.dataset.validate and type == "training" ) for data_dir in tqdm(dataset, desc=f"Parsing dataset: {type}", disable=silent) }
|
||||||
|
|
||||||
def _load_paths_from_metadata(group_name, type="training", validate=False):
|
def _load_paths_from_metadata(group_name, type="training", validate=False):
|
||||||
data_dir = group_name if cfg.dataset.use_hdf5 else cfg.data_dir / group_name
|
data_dir = group_name if cfg.dataset.use_hdf5 else cfg.data_dir / group_name
|
||||||
|
|
|
@ -124,7 +124,7 @@ def main():
|
||||||
# pull from dataset samples
|
# pull from dataset samples
|
||||||
if args.sample_from_dataset:
|
if args.sample_from_dataset:
|
||||||
cfg.dataset.cache = False
|
cfg.dataset.cache = False
|
||||||
cfg.dataset.sample_type = "speaker"
|
cfg.dataset.sample_type = "path" if args.lora else "speaker"
|
||||||
cfg.dataset.tasks_list = [ 'tts' ]
|
cfg.dataset.tasks_list = [ 'tts' ]
|
||||||
|
|
||||||
samples_dirs["dataset"] = args.demo_dir / "dataset"
|
samples_dirs["dataset"] = args.demo_dir / "dataset"
|
||||||
|
@ -133,7 +133,7 @@ def main():
|
||||||
dataloader = create_train_dataloader()
|
dataloader = create_train_dataloader()
|
||||||
_logger.info("Loaded dataloader.")
|
_logger.info("Loaded dataloader.")
|
||||||
|
|
||||||
length = len( dataloader.dataset )
|
length = min(len( dataloader.dataset ), cfg.evaluation.batch_size)
|
||||||
num = args.dataset_samples if args.dataset_samples else length
|
num = args.dataset_samples if args.dataset_samples else length
|
||||||
|
|
||||||
for i in trange( num, desc="Sampling dataset for samples" ):
|
for i in trange( num, desc="Sampling dataset for samples" ):
|
||||||
|
@ -233,6 +233,9 @@ def main():
|
||||||
# write audio into template
|
# write audio into template
|
||||||
html = html.replace("${"+k.upper()+"_SAMPLES}", "\n".join( samples ) )
|
html = html.replace("${"+k.upper()+"_SAMPLES}", "\n".join( samples ) )
|
||||||
|
|
||||||
|
if not args.lora:
|
||||||
|
html = html.replace("\n\t\t\t\t\t<th>Our VALL-E (No LoRA)</th>", "")
|
||||||
|
|
||||||
# write demo page
|
# write demo page
|
||||||
open( args.demo_dir / "index.html", "w", encoding="utf-8" ).write( html )
|
open( args.demo_dir / "index.html", "w", encoding="utf-8" ).write( html )
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user