Fix ddp for sampler
This commit is contained in:
parent
b521d94b01
commit
9dfe936c16
|
@ -114,9 +114,11 @@ class Trainer:
|
|||
if opt['dist']:
|
||||
self.train_sampler = DistIterSampler(self.train_set, self.world_size, self.rank, dataset_ratio)
|
||||
self.total_epochs = int(math.ceil(total_iters / (train_size * dataset_ratio)))
|
||||
shuffle = False
|
||||
else:
|
||||
self.train_sampler = None
|
||||
self.train_loader = create_dataloader(self.train_set, dataset_opt, opt, self.train_sampler, collate_fn=collate_fn)
|
||||
shuffle = True
|
||||
self.train_loader = create_dataloader(self.train_set, dataset_opt, opt, self.train_sampler, collate_fn=collate_fn, shuffle=shuffle)
|
||||
if self.rank <= 0:
|
||||
self.logger.info('Number of train images: {:,d}, iters: {:,d}'.format(
|
||||
len(self.train_set), train_size))
|
||||
|
|
Loading…
Reference in New Issue
Block a user