diff --git a/codes/trainer/steps.py b/codes/trainer/steps.py
index f399beea..3f210edd 100644
--- a/codes/trainer/steps.py
+++ b/codes/trainer/steps.py
@@ -321,6 +321,10 @@ class ConfigurableStep(Module):
         if not nan_found:
             self.scaler.step(opt)
             self.scaler.update()
+        else:
+            for pg in opt.param_groups:
+                for p in pg['params']:
+                    p.grad = None
 
     def get_metrics(self):
         return self.loss_accumulator.as_dict()
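
Below is a minimal, self-contained sketch of the pattern this patch implements: an AMP training step that unscales the gradients, checks them for NaN/Inf, and drops them instead of stepping the optimizer when they are corrupted. The names `model`, `opt`, `scaler`, `batch`, and `compute_loss` are hypothetical placeholders, not identifiers from this repository; this illustrates the technique under those assumptions rather than reproducing the project's actual training loop.

    import torch

    # Hypothetical standalone example of NaN-guarded AMP stepping; `model`, `opt`,
    # `scaler`, `batch`, and `compute_loss` are placeholders for illustration only.
    def train_step(model, opt, scaler, batch, compute_loss):
        opt.zero_grad(set_to_none=True)
        with torch.cuda.amp.autocast():
            loss = compute_loss(model, batch)
        scaler.scale(loss).backward()
        scaler.unscale_(opt)  # unscale in place so the raw gradients can be inspected

        # Look for NaN/Inf in any parameter gradient before applying the step.
        nan_found = any(
            p.grad is not None and not torch.isfinite(p.grad).all()
            for pg in opt.param_groups
            for p in pg['params']
        )

        if not nan_found:
            scaler.step(opt)
            scaler.update()
        else:
            # Discard the corrupted gradients so they cannot be applied by a later step.
            for pg in opt.param_groups:
                for p in pg['params']:
                    p.grad = None

Assigning `p.grad = None` releases the gradient tensors and matches what `opt.zero_grad(set_to_none=True)` does; assigning a plain `0` would raise a TypeError, since PyTorch only accepts a Tensor or None for `.grad`.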