diff --git a/vall_e/demo.py b/vall_e/demo.py index 48098f3..80da923 100644 --- a/vall_e/demo.py +++ b/vall_e/demo.py @@ -121,18 +121,20 @@ def main(): # pull from dataset samples if args.sample_from_dataset: cfg.dataset.cache = False + cfg.dataset.sample_type = "speaker" + cfg.dataset.tasks_list = [ 'tts' ] + samples_dirs["dataset"] = args.demo_dir / "dataset" _logger.info("Loading dataloader...") dataloader = create_train_dataloader() _logger.info("Loaded dataloader.") - num = args.dataset_samples if args.dataset_samples else cfg.evaluation.size - length = len( dataloader.dataset ) + num = args.dataset_samples if args.dataset_samples else length + for i in trange( num, desc="Sampling dataset for samples" ): - idx = random.randint( 0, length ) - batch = dataloader.dataset[idx] + batch = dataloader.dataset[i] dir = args.demo_dir / "dataset" / f'{i}' @@ -141,7 +143,7 @@ def main(): metadata = batch["metadata"] text = metadata["text"] - language = metadata["language"] + language = metadata["language"].lower() prompt = dir / "prompt.wav" reference = dir / "reference.wav"