Maybe fixes the eval dataloader not shuffling under distributed training

This commit is contained in:
mrq 2023-08-17 13:41:53 -05:00
parent 03872b823f
commit 18403a3523

View File

@ -321,17 +321,24 @@ def _seed_worker(worker_id):
def _create_dataloader(dataset, training):
    """Build the ``DataLoader`` for a training or evaluation dataset.

    Args:
        dataset: Dataset to draw batches from.
        training: When true, use the training batch size and drop the last
            incomplete batch; otherwise use the evaluation batch size and
            keep every batch.

    Returns:
        A ``DataLoader`` over ``dataset``. For distributed training it is
        driven by a ``DistributedSampler``; in every other case (including
        evaluation under distributed) it shuffles without a sampler.
    """
    # DataLoader treats ``shuffle=True`` and an explicit ``sampler`` as
    # mutually exclusive, so exactly one of the two is enabled here.
    sampler = None
    shuffle = True
    if cfg.distributed and training:
        sampler = DistributedSampler(dataset)
        shuffle = False

    num_workers = cfg.dataset.workers
    return DataLoader(
        dataset=dataset,
        batch_size=cfg.hyperparameters.batch_size if training else cfg.evaluation.batch_size,
        shuffle=shuffle,
        drop_last=training,
        num_workers=num_workers,
        collate_fn=collate_fn,
        # DataLoader raises ValueError for persistent workers when there are
        # no worker processes, so only request them if workers are configured.
        persistent_workers=num_workers > 0,
        pin_memory=False,
        worker_init_fn=_seed_worker,
        sampler=sampler,
    )
def _load_dataset_paths():