adjust location of pre-optimizer step so I can visualize the new grad norms

James Betker 2022-03-04 08:56:42 -07:00
parent 58019a2ce3
commit 3c242403f5


@@ -303,6 +303,14 @@ class ExtensibleTrainer(BaseModel):
                        raise OverwrittenStateError(k, list(state.keys()))
                    state[k] = v
            # (Maybe) perform a step.
            if train_step and optimize and self.batch_size_optimizer.should_step(it):
                # Call into pre-step hooks.
                for name, net in self.networks.items():
                    if hasattr(net.module, "before_step"):
                        net.module.before_step(it)
            if return_grad_norms and train_step:
                for name in nets_to_train:
                    model = self.networks[name]
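
The block introduced here gives any network that defines a before_step method a chance to run after backward() and before the gradient norms are read, so whatever before_step does to the gradients shows up in the logged norms. A minimal sketch of that hook pattern, assuming a hypothetical GradClipNet and an unwrapped network (the trainer itself goes through net.module because its networks sit behind a wrapper such as DataParallel/DDP):

```python
import torch
import torch.nn as nn

class GradClipNet(nn.Module):
    """Hypothetical network that uses a before_step hook to touch up its gradients."""
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 8)

    def forward(self, x):
        return self.fc(x)

    def before_step(self, it):
        # Runs after backward() but before the optimizer step, so anything it does
        # to .grad is reflected in grad norms measured afterwards.
        torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=1.0)

net = GradClipNet()
opt = torch.optim.SGD(net.parameters(), lr=0.1)
for it in range(3):
    opt.zero_grad()
    loss = net(torch.randn(4, 8)).pow(2).mean()
    loss.backward()
    # Same duck-typed check the trainer performs before stepping.
    if hasattr(net, "before_step"):
        net.before_step(it)
    grad_norm = torch.norm(torch.stack([p.grad.norm(2) for p in net.parameters()]), 2)
    opt.step()
```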
@@ -313,9 +321,6 @@ class ExtensibleTrainer(BaseModel):
                for name in pgroups.keys():
                    grad_norms[name] = torch.norm(torch.stack([torch.norm(p.grad.detach(), 2) for p in pgroups[name]]), 2)
            # (Maybe) perform a step.
            if train_step and optimize and self.batch_size_optimizer.should_step(it):
                self.consume_gradients(state, step, it)
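
The grad_norms line reduces each parameter group to a single scalar by stacking the per-parameter gradient norms and taking the norm of that vector, which equals the L2 norm over every gradient element in the group. A small self-contained check of that identity (the toy model and group name are illustrative):

```python
import torch
import torch.nn as nn

model = nn.Linear(16, 4)
model(torch.randn(2, 16)).sum().backward()

# Illustrative grouping; the trainer builds groups from get_grad_norm_parameter_groups()
# when a network provides it, else lumps everything into one "<name>_all_parameters" group.
pgroups = {'toy_all_parameters': list(model.parameters())}

grad_norms = {}
for name in pgroups.keys():
    grad_norms[name] = torch.norm(torch.stack([torch.norm(p.grad.detach(), 2) for p in pgroups[name]]), 2)

# Norm-of-norms equals the L2 norm over all gradient elements in the group.
flat = torch.cat([p.grad.detach().flatten() for p in pgroups['toy_all_parameters']])
assert torch.allclose(grad_norms['toy_all_parameters'], flat.norm(2))
```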
@@ -360,11 +365,6 @@ class ExtensibleTrainer(BaseModel):
    def consume_gradients(self, state, step, it):
        [e.before_optimize(state) for e in self.experiments]
        # Call into pre-step hooks.
        for name, net in self.networks.items():
            if hasattr(net.module, "before_step"):
                net.module.before_step(it)
        step.do_step(it)
        # Call into custom step hooks as well as update EMA params.
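
With the pre-step hooks moved out, consume_gradients is left with the experiment callbacks, the optimizer step itself (step.do_step), and the post-step bookkeeping the trailing comment mentions, such as updating EMA parameters. A minimal sketch of that kind of post-step EMA update, assuming a hypothetical update_ema helper and decay value rather than the repo's actual implementation:

```python
import copy
import torch
import torch.nn as nn

def update_ema(ema_model, model, decay=0.999):
    """Post-step hook: keep an exponential moving average of the weights."""
    with torch.no_grad():
        for ema_p, p in zip(ema_model.parameters(), model.parameters()):
            ema_p.mul_(decay).add_(p, alpha=1 - decay)

model = nn.Linear(8, 8)
ema_model = copy.deepcopy(model)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)

for it in range(5):
    opt.zero_grad()
    model(torch.randn(4, 8)).pow(2).mean().backward()
    opt.step()                      # analogous to step.do_step(it)
    update_ema(ema_model, model)    # "update EMA params" bookkeeping after the step
```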