Maybe fixes the eval dataloader not shuffling when running distributed

This commit is contained in:
mrq 2023-08-17 13:41:53 -05:00
parent 03872b823f
commit 18403a3523

View File

@ -321,17 +321,24 @@ def _seed_worker(worker_id):
def _create_dataloader(dataset, training): def _create_dataloader(dataset, training):
sampler = None
shuffle = True
if cfg.distributed and training:
sampler = DistributedSampler(dataset)
shuffle = False
return DataLoader( return DataLoader(
dataset=dataset, dataset=dataset,
batch_size=cfg.hyperparameters.batch_size if training else cfg.evaluation.batch_size, batch_size=cfg.hyperparameters.batch_size if training else cfg.evaluation.batch_size,
shuffle=False, # if cfg.distributed else True, # training shuffle=shuffle,
drop_last=training, drop_last=training,
num_workers=cfg.dataset.workers, num_workers=cfg.dataset.workers,
collate_fn=collate_fn, collate_fn=collate_fn,
persistent_workers=True, persistent_workers=True,
pin_memory=False, # True, pin_memory=False, # True,
worker_init_fn=_seed_worker, worker_init_fn=_seed_worker,
sampler=DistributedSampler(dataset) if cfg.distributed else None # dataset.sampler sampler=sampler,
) )
def _load_dataset_paths(): def _load_dataset_paths():