diff --git a/codes/train.py b/codes/train.py index 7b2da295..7908f857 100644 --- a/codes/train.py +++ b/codes/train.py @@ -114,9 +114,11 @@ class Trainer: if opt['dist']: self.train_sampler = DistIterSampler(self.train_set, self.world_size, self.rank, dataset_ratio) self.total_epochs = int(math.ceil(total_iters / (train_size * dataset_ratio))) + shuffle = False else: self.train_sampler = None - self.train_loader = create_dataloader(self.train_set, dataset_opt, opt, self.train_sampler, collate_fn=collate_fn) + shuffle = True + self.train_loader = create_dataloader(self.train_set, dataset_opt, opt, self.train_sampler, collate_fn=collate_fn, shuffle=shuffle) if self.rank <= 0: self.logger.info('Number of train images: {:,d}, iters: {:,d}'.format( len(self.train_set), train_size))