From 894d245062cbc41b608c8ed9d3d5eab834dd1c38 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Sat, 8 Jan 2022 20:31:19 -0700
Subject: [PATCH] More zero_grad fixes

---
 codes/train.py         | 4 ++++
 codes/trainer/steps.py | 4 +---
 2 files changed, 5 insertions(+), 3 deletions(-)

diff --git a/codes/train.py b/codes/train.py
index dfe5f5c3..1a9d118d 100644
--- a/codes/train.py
+++ b/codes/train.py
@@ -267,6 +267,10 @@ class Trainer:
             import wandb
             wandb.log(eval_dict)
 
+        # Should not be necessary, but make absolutely sure that there is no grad leakage from validation runs.
+        for net in self.model.networks.values():
+            net.zero_grad()
+
     def do_training(self):
         self.logger.info('Start training from epoch: {:d}, iter: {:d}'.format(self.start_epoch, self.current_step))
         for epoch in range(self.start_epoch, self.total_epochs + 1):
diff --git a/codes/trainer/steps.py b/codes/trainer/steps.py
index 3f210edd..3ac97f25 100644
--- a/codes/trainer/steps.py
+++ b/codes/trainer/steps.py
@@ -322,9 +322,7 @@ class ConfigurableStep(Module):
             self.scaler.step(opt)
             self.scaler.update()
         else:
-            for pg in opt.param_groups:
-                for p in pg['params']:
-                    p.grad = 0
+            opt.zero_grad()
 
     def get_metrics(self):
         return self.loss_accumulator.as_dict()
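
A minimal standalone sketch (not part of the patch; the tiny model and optimizer here are hypothetical stand-ins) of why the removed loop was a bug and what the replacement does: assigning the int `0` to `p.grad` hands PyTorch something it expects to be a tensor or `None`, whereas `Optimizer.zero_grad()` and `Module.zero_grad()` clear gradients the supported way.

```python
import torch
import torch.nn as nn

net = nn.Linear(4, 2)
opt = torch.optim.SGD(net.parameters(), lr=0.1)

# Populate gradients, as a validation forward/backward pass might.
loss = net(torch.randn(8, 4)).sum()
loss.backward()

# Buggy pattern removed by the patch: assigns the *int* 0 to .grad,
# where PyTorch expects a tensor or None; recent PyTorch rejects the
# assignment, and autograd cannot accumulate into it either way.
#
# for pg in opt.param_groups:
#     for p in pg['params']:
#         p.grad = 0

# Idiomatic replacement: clears every parameter gradient the
# optimizer manages.
opt.zero_grad()

# Belt-and-braces step mirroring the train.py hunk: Module.zero_grad()
# clears gradients on all of the module's own parameters, guarding
# against gradients from a validation run leaking into training.
net.zero_grad()
```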