diff --git a/codes/train.py b/codes/train.py
index dfe5f5c3..1a9d118d 100644
--- a/codes/train.py
+++ b/codes/train.py
@@ -267,6 +267,10 @@ class Trainer:
                     import wandb
                     wandb.log(eval_dict)
 
+        # Should not be necessary, but make absolutely sure that there is no grad leakage from validation runs.
+        for net in self.model.networks.values():
+            net.zero_grad()
+
     def do_training(self):
         self.logger.info('Start training from epoch: {:d}, iter: {:d}'.format(self.start_epoch, self.current_step))
         for epoch in range(self.start_epoch, self.total_epochs + 1):
diff --git a/codes/trainer/steps.py b/codes/trainer/steps.py
index 3f210edd..3ac97f25 100644
--- a/codes/trainer/steps.py
+++ b/codes/trainer/steps.py
@@ -322,9 +322,7 @@ class ConfigurableStep(Module):
             self.scaler.step(opt)
             self.scaler.update()
         else:
-            for pg in opt.param_groups:
-                for p in pg['params']:
-                    p.grad = 0
+            opt.zero_grad()
 
     def get_metrics(self):
         return self.loss_accumulator.as_dict()
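
For context on the second hunk: the removed loop assigned the plain integer `0` to `p.grad` rather than clearing the accumulated gradient tensors, while `Optimizer.zero_grad()` (and `nn.Module.zero_grad()`, used in the first hunk) is PyTorch's built-in way to do this. Below is a minimal sketch of the behavior both changes rely on; the toy `nn.Linear` model and SGD optimizer are illustrative assumptions, not code from this repo.

```python
# Minimal sketch (not from this repo): zero_grad() on an optimizer or module
# clears the .grad values that backward() accumulated.
import torch
import torch.nn as nn

net = nn.Linear(4, 2)                            # stand-in for one entry of self.model.networks
opt = torch.optim.SGD(net.parameters(), lr=0.1)

loss = net(torch.randn(8, 4)).sum()
loss.backward()                                  # populates p.grad for every parameter

opt.zero_grad()                                  # clears grads for the params this optimizer owns
net.zero_grad()                                  # clears grads for all of the module's parameters

# Depending on the PyTorch version, grads are now either None or zero tensors.
print(all(p.grad is None or not p.grad.any() for p in net.parameters()))  # True
```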