forked from mrq/DL-Art-School
More zero_grad fixes
This commit is contained in:
parent
8bade38180
commit
894d245062
|
@ -267,6 +267,10 @@ class Trainer:
|
|||
import wandb
|
||||
wandb.log(eval_dict)
|
||||
|
||||
# Should not be necessary, but make absolutely sure that there is no grad leakage from validation runs.
|
||||
for net in self.model.networks.values():
|
||||
net.zero_grad()
|
||||
|
||||
def do_training(self):
|
||||
self.logger.info('Start training from epoch: {:d}, iter: {:d}'.format(self.start_epoch, self.current_step))
|
||||
for epoch in range(self.start_epoch, self.total_epochs + 1):
|
||||
|
|
|
@ -322,9 +322,7 @@ class ConfigurableStep(Module):
|
|||
self.scaler.step(opt)
|
||||
self.scaler.update()
|
||||
else:
|
||||
for pg in opt.param_groups:
|
||||
for p in pg['params']:
|
||||
p.grad = 0
|
||||
opt.zero_grad()
|
||||
|
||||
def get_metrics(self):
|
||||
return self.loss_accumulator.as_dict()
|
||||
|
|
Loading…
Reference in New Issue
Block a user