More zero_grad fixes

James Betker 2022-01-08 20:31:19 -07:00
parent 8bade38180
commit 894d245062
2 changed files with 5 additions and 3 deletions


@@ -267,6 +267,10 @@ class Trainer:
                 import wandb
                 wandb.log(eval_dict)
+
+        # Should not be necessary, but make absolutely sure that there is no grad leakage from validation runs.
+        for net in self.model.networks.values():
+            net.zero_grad()
 
     def do_training(self):
         self.logger.info('Start training from epoch: {:d}, iter: {:d}'.format(self.start_epoch, self.current_step))
         for epoch in range(self.start_epoch, self.total_epochs + 1):

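The hunk above clears gradients on every tracked network once the eval/wandb logging block finishes, so nothing accumulated during a validation run can leak into the next optimizer step. A minimal standalone sketch of the same defensive pattern, with a networks dict standing in for self.model.networks (the names below are illustrative, not the project's own):

import torch
import torch.nn as nn

# Hypothetical stand-in for the trainer's self.model.networks dict.
networks = {'generator': nn.Linear(8, 8), 'discriminator': nn.Linear(8, 1)}

def run_validation(net):
    # An eval pass; if any code path ever backprops through eval outputs,
    # stale .grad tensors would otherwise survive into the next training step.
    net.eval()
    return net(torch.randn(4, 8)).mean().item()

for net in networks.values():
    run_validation(net)
    net.zero_grad()  # same guarantee as the added lines: no grad leakage from validation

nn.Module.zero_grad() clears .grad on every parameter owned by that module (and its submodules), which is why the trainer can do this without needing a handle on the individual optimizers.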

@@ -322,9 +322,7 @@ class ConfigurableStep(Module):
             self.scaler.step(opt)
             self.scaler.update()
         else:
-            for pg in opt.param_groups:
-                for p in pg['params']:
-                    p.grad = 0
+            opt.zero_grad()
 
     def get_metrics(self):
         return self.loss_accumulator.as_dict()
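
For context on the second hunk: PyTorch expects a parameter's .grad to be a Tensor or None, and recent releases reject any other assignment outright, so the removed loop that wrote the integer 0 into .grad was fragile at best; optimizer.zero_grad() is the supported way to clear gradients. A small illustrative sketch (the toy model and SGD optimizer below are assumptions, not this project's code):

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
opt = torch.optim.SGD(model.parameters(), lr=0.1)

model(torch.randn(3, 4)).sum().backward()   # populate .grad on every parameter

# Removed approach (left commented out): replacing the gradient tensor with a plain int.
# for pg in opt.param_groups:
#     for p in pg['params']:
#         p.grad = 0

# Replacement from the diff: clear gradients through the optimizer instead.
opt.zero_grad()

# Depending on the PyTorch version, zero_grad() either zeroes the tensors or sets them to None.
print(all(p.grad is None or not p.grad.any() for p in model.parameters()))

Either way, the next backward() starts accumulating from a clean slate, which is the point of both hunks in this commit.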