diff --git a/vall_e/data.py b/vall_e/data.py index c06daf6..382c36a 100755 --- a/vall_e/data.py +++ b/vall_e/data.py @@ -321,17 +321,24 @@ def _seed_worker(worker_id): def _create_dataloader(dataset, training): + sampler = None + shuffle = True + + if cfg.distributed and training: + sampler = DistributedSampler(dataset) + shuffle = False + return DataLoader( dataset=dataset, batch_size=cfg.hyperparameters.batch_size if training else cfg.evaluation.batch_size, - shuffle=False, # if cfg.distributed else True, # training + shuffle=shuffle, drop_last=training, num_workers=cfg.dataset.workers, collate_fn=collate_fn, persistent_workers=True, pin_memory=False, # True, worker_init_fn=_seed_worker, - sampler=DistributedSampler(dataset) if cfg.distributed else None # dataset.sampler + sampler=sampler, ) def _load_dataset_paths():