adjust location of pre-optimizer step so I can visualize the new grad norms

This commit is contained in:
James Betker 2022-03-04 08:56:42 -07:00
parent 58019a2ce3
commit 3c242403f5

View File

@ -303,19 +303,24 @@ class ExtensibleTrainer(BaseModel):
raise OverwrittenStateError(k, list(state.keys()))
state[k] = v
if return_grad_norms and train_step:
for name in nets_to_train:
model = self.networks[name]
if hasattr(model.module, 'get_grad_norm_parameter_groups'):
pgroups = {f'{name}_{k}': v for k, v in model.module.get_grad_norm_parameter_groups().items()}
else:
pgroups = {f'{name}_all_parameters': list(model.parameters())}
for name in pgroups.keys():
grad_norms[name] = torch.norm(torch.stack([torch.norm(p.grad.detach(), 2) for p in pgroups[name]]), 2)
# (Maybe) perform a step.
if train_step and optimize and self.batch_size_optimizer.should_step(it):
# Call into pre-step hooks.
for name, net in self.networks.items():
if hasattr(net.module, "before_step"):
net.module.before_step(it)
if return_grad_norms and train_step:
for name in nets_to_train:
model = self.networks[name]
if hasattr(model.module, 'get_grad_norm_parameter_groups'):
pgroups = {f'{name}_{k}': v for k, v in model.module.get_grad_norm_parameter_groups().items()}
else:
pgroups = {f'{name}_all_parameters': list(model.parameters())}
for name in pgroups.keys():
grad_norms[name] = torch.norm(torch.stack([torch.norm(p.grad.detach(), 2) for p in pgroups[name]]), 2)
self.consume_gradients(state, step, it)
@ -360,11 +365,6 @@ class ExtensibleTrainer(BaseModel):
def consume_gradients(self, state, step, it):
[e.before_optimize(state) for e in self.experiments]
# Call into pre-step hooks.
for name, net in self.networks.items():
if hasattr(net.module, "before_step"):
net.module.before_step(it)
step.do_step(it)
# Call into custom step hooks as well as update EMA params.