diff --git a/codes/train.py b/codes/train.py
index dfe5f5c3..1a9d118d 100644
--- a/codes/train.py
+++ b/codes/train.py
@@ -267,6 +267,10 @@ class Trainer:
                     import wandb
                     wandb.log(eval_dict)
 
+        # Should not be necessary, but make absolutely sure that there is no grad leakage from validation runs.
+        for net in self.model.networks.values():
+            net.zero_grad()
+
     def do_training(self):
         self.logger.info('Start training from epoch: {:d}, iter: {:d}'.format(self.start_epoch, self.current_step))
         for epoch in range(self.start_epoch, self.total_epochs + 1):
diff --git a/codes/trainer/steps.py b/codes/trainer/steps.py
index 3f210edd..3ac97f25 100644
--- a/codes/trainer/steps.py
+++ b/codes/trainer/steps.py
@@ -322,9 +322,7 @@ class ConfigurableStep(Module):
                 self.scaler.step(opt)
                 self.scaler.update()
             else:
-                for pg in opt.param_groups:
-                    for p in pg['params']:
-                        p.grad = 0
+                opt.zero_grad()
 
     def get_metrics(self):
         return self.loss_accumulator.as_dict()