Fix ddp for sampler
This commit is contained in:
parent
b521d94b01
commit
9dfe936c16
|
@ -114,9 +114,11 @@ class Trainer:
|
||||||
if opt['dist']:
|
if opt['dist']:
|
||||||
self.train_sampler = DistIterSampler(self.train_set, self.world_size, self.rank, dataset_ratio)
|
self.train_sampler = DistIterSampler(self.train_set, self.world_size, self.rank, dataset_ratio)
|
||||||
self.total_epochs = int(math.ceil(total_iters / (train_size * dataset_ratio)))
|
self.total_epochs = int(math.ceil(total_iters / (train_size * dataset_ratio)))
|
||||||
|
shuffle = False
|
||||||
else:
|
else:
|
||||||
self.train_sampler = None
|
self.train_sampler = None
|
||||||
self.train_loader = create_dataloader(self.train_set, dataset_opt, opt, self.train_sampler, collate_fn=collate_fn)
|
shuffle = True
|
||||||
|
self.train_loader = create_dataloader(self.train_set, dataset_opt, opt, self.train_sampler, collate_fn=collate_fn, shuffle=shuffle)
|
||||||
if self.rank <= 0:
|
if self.rank <= 0:
|
||||||
self.logger.info('Number of train images: {:,d}, iters: {:,d}'.format(
|
self.logger.info('Number of train images: {:,d}, iters: {:,d}'.format(
|
||||||
len(self.train_set), train_size))
|
len(self.train_set), train_size))
|
||||||
|
|
Loading…
Reference in New Issue
Block a user