diff --git a/codes/train.py b/codes/train.py index a92e7b35..334ffab2 100644 --- a/codes/train.py +++ b/codes/train.py @@ -291,7 +291,7 @@ class Trainer: self.logger.info('Start training from epoch: {:d}, iter: {:d}'.format(self.start_epoch, self.current_step)) for epoch in range(self.start_epoch, self.total_epochs + 1): self.epoch = epoch - if opt['dist']: + if self.opt['dist']: self.train_sampler.set_epoch(epoch) tq_ldr = tqdm(self.train_loader) if self.rank <= 0 else self.train_loader