forked from mrq/DL-Art-School
adjust location of pre-optimizer step so I can visualize the new grad norms
This commit is contained in:
parent
58019a2ce3
commit
3c242403f5
|
@ -303,19 +303,24 @@ class ExtensibleTrainer(BaseModel):
|
|||
raise OverwrittenStateError(k, list(state.keys()))
|
||||
state[k] = v
|
||||
|
||||
if return_grad_norms and train_step:
|
||||
for name in nets_to_train:
|
||||
model = self.networks[name]
|
||||
if hasattr(model.module, 'get_grad_norm_parameter_groups'):
|
||||
pgroups = {f'{name}_{k}': v for k, v in model.module.get_grad_norm_parameter_groups().items()}
|
||||
else:
|
||||
pgroups = {f'{name}_all_parameters': list(model.parameters())}
|
||||
for name in pgroups.keys():
|
||||
grad_norms[name] = torch.norm(torch.stack([torch.norm(p.grad.detach(), 2) for p in pgroups[name]]), 2)
|
||||
|
||||
|
||||
# (Maybe) perform a step.
|
||||
if train_step and optimize and self.batch_size_optimizer.should_step(it):
|
||||
# Call into pre-step hooks.
|
||||
for name, net in self.networks.items():
|
||||
if hasattr(net.module, "before_step"):
|
||||
net.module.before_step(it)
|
||||
|
||||
if return_grad_norms and train_step:
|
||||
for name in nets_to_train:
|
||||
model = self.networks[name]
|
||||
if hasattr(model.module, 'get_grad_norm_parameter_groups'):
|
||||
pgroups = {f'{name}_{k}': v for k, v in model.module.get_grad_norm_parameter_groups().items()}
|
||||
else:
|
||||
pgroups = {f'{name}_all_parameters': list(model.parameters())}
|
||||
for name in pgroups.keys():
|
||||
grad_norms[name] = torch.norm(torch.stack([torch.norm(p.grad.detach(), 2) for p in pgroups[name]]), 2)
|
||||
|
||||
self.consume_gradients(state, step, it)
|
||||
|
||||
|
||||
|
@ -360,11 +365,6 @@ class ExtensibleTrainer(BaseModel):
|
|||
|
||||
def consume_gradients(self, state, step, it):
|
||||
[e.before_optimize(state) for e in self.experiments]
|
||||
# Call into pre-step hooks.
|
||||
for name, net in self.networks.items():
|
||||
if hasattr(net.module, "before_step"):
|
||||
net.module.before_step(it)
|
||||
|
||||
step.do_step(it)
|
||||
|
||||
# Call into custom step hooks as well as update EMA params.
|
||||
|
|
Loading…
Reference in New Issue
Block a user