From 9dfe936c16a6c7664e2221cf9b66973506f06389 Mon Sep 17 00:00:00 2001 From: James Betker Date: Thu, 19 Aug 2021 16:45:34 -0600 Subject: [PATCH] Fix ddp for sampler --- codes/train.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/codes/train.py b/codes/train.py index 7b2da295..7908f857 100644 --- a/codes/train.py +++ b/codes/train.py @@ -114,9 +114,11 @@ class Trainer: if opt['dist']: self.train_sampler = DistIterSampler(self.train_set, self.world_size, self.rank, dataset_ratio) self.total_epochs = int(math.ceil(total_iters / (train_size * dataset_ratio))) + shuffle = False else: self.train_sampler = None - self.train_loader = create_dataloader(self.train_set, dataset_opt, opt, self.train_sampler, collate_fn=collate_fn) + shuffle = True + self.train_loader = create_dataloader(self.train_set, dataset_opt, opt, self.train_sampler, collate_fn=collate_fn, shuffle=shuffle) if self.rank <= 0: self.logger.info('Number of train images: {:,d}, iters: {:,d}'.format( len(self.train_set), train_size))