More zero_grad fixes

This commit is contained in:
James Betker 2022-01-08 20:31:19 -07:00
parent 8bade38180
commit 894d245062
2 changed files with 5 additions and 3 deletions

View File

@ -267,6 +267,10 @@ class Trainer:
import wandb
wandb.log(eval_dict)
# Should not be necessary, but make absolutely sure that there is no grad leakage from validation runs.
for net in self.model.networks.values():
net.zero_grad()
def do_training(self):
self.logger.info('Start training from epoch: {:d}, iter: {:d}'.format(self.start_epoch, self.current_step))
for epoch in range(self.start_epoch, self.total_epochs + 1):

View File

@ -322,9 +322,7 @@ class ConfigurableStep(Module):
self.scaler.step(opt)
self.scaler.update()
else:
for pg in opt.param_groups:
for p in pg['params']:
p.grad = 0
opt.zero_grad()
def get_metrics(self):
return self.loss_accumulator.as_dict()