Fix DDP for sampler

This commit is contained in:
James Betker 2021-08-19 16:45:34 -06:00
parent b521d94b01
commit 9dfe936c16

View File

@@ -114,9 +114,11 @@ class Trainer:
         if opt['dist']:
             self.train_sampler = DistIterSampler(self.train_set, self.world_size, self.rank, dataset_ratio)
             self.total_epochs = int(math.ceil(total_iters / (train_size * dataset_ratio)))
+            shuffle = False
         else:
             self.train_sampler = None
-        self.train_loader = create_dataloader(self.train_set, dataset_opt, opt, self.train_sampler, collate_fn=collate_fn)
+            shuffle = True
+        self.train_loader = create_dataloader(self.train_set, dataset_opt, opt, self.train_sampler, collate_fn=collate_fn, shuffle=shuffle)
         if self.rank <= 0:
             self.logger.info('Number of train images: {:,d}, iters: {:,d}'.format(
                 len(self.train_set), train_size))