tweaked demo page script to sample speakers instead

This commit is contained in:
mrq 2024-09-28 10:50:26 -05:00
parent 2f1dca3089
commit a9fa0898a9

View File

@ -121,18 +121,20 @@ def main():
# pull from dataset samples # pull from dataset samples
if args.sample_from_dataset: if args.sample_from_dataset:
cfg.dataset.cache = False cfg.dataset.cache = False
cfg.dataset.sample_type = "speaker"
cfg.dataset.tasks_list = [ 'tts' ]
samples_dirs["dataset"] = args.demo_dir / "dataset" samples_dirs["dataset"] = args.demo_dir / "dataset"
_logger.info("Loading dataloader...") _logger.info("Loading dataloader...")
dataloader = create_train_dataloader() dataloader = create_train_dataloader()
_logger.info("Loaded dataloader.") _logger.info("Loaded dataloader.")
num = args.dataset_samples if args.dataset_samples else cfg.evaluation.size
length = len( dataloader.dataset ) length = len( dataloader.dataset )
num = args.dataset_samples if args.dataset_samples else length
for i in trange( num, desc="Sampling dataset for samples" ): for i in trange( num, desc="Sampling dataset for samples" ):
idx = random.randint( 0, length ) batch = dataloader.dataset[i]
batch = dataloader.dataset[idx]
dir = args.demo_dir / "dataset" / f'{i}' dir = args.demo_dir / "dataset" / f'{i}'
@ -141,7 +143,7 @@ def main():
metadata = batch["metadata"] metadata = batch["metadata"]
text = metadata["text"] text = metadata["text"]
language = metadata["language"] language = metadata["language"].lower()
prompt = dir / "prompt.wav" prompt = dir / "prompt.wav"
reference = dir / "reference.wav" reference = dir / "reference.wav"