forked from mrq/DL-Art-School
adjust location of pre-optimizer step so I can visualize the new grad norms
commit 3c242403f5
parent 58019a2ce3
@@ -303,19 +303,24 @@ class ExtensibleTrainer(BaseModel):
                         raise OverwrittenStateError(k, list(state.keys()))
                     state[k] = v
 
-            if return_grad_norms and train_step:
-                for name in nets_to_train:
-                    model = self.networks[name]
-                    if hasattr(model.module, 'get_grad_norm_parameter_groups'):
-                        pgroups = {f'{name}_{k}': v for k, v in model.module.get_grad_norm_parameter_groups().items()}
-                    else:
-                        pgroups = {f'{name}_all_parameters': list(model.parameters())}
-                    for name in pgroups.keys():
-                        grad_norms[name] = torch.norm(torch.stack([torch.norm(p.grad.detach(), 2) for p in pgroups[name]]), 2)
 
             # (Maybe) perform a step.
             if train_step and optimize and self.batch_size_optimizer.should_step(it):
+                # Call into pre-step hooks.
+                for name, net in self.networks.items():
+                    if hasattr(net.module, "before_step"):
+                        net.module.before_step(it)
+
+                if return_grad_norms and train_step:
+                    for name in nets_to_train:
+                        model = self.networks[name]
+                        if hasattr(model.module, 'get_grad_norm_parameter_groups'):
+                            pgroups = {f'{name}_{k}': v for k, v in model.module.get_grad_norm_parameter_groups().items()}
+                        else:
+                            pgroups = {f'{name}_all_parameters': list(model.parameters())}
+                        for name in pgroups.keys():
+                            grad_norms[name] = torch.norm(torch.stack([torch.norm(p.grad.detach(), 2) for p in pgroups[name]]), 2)
                 self.consume_gradients(state, step, it)
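For reference, the relocated block computes one scalar per parameter group: the L2 norm of the stacked per-parameter gradient L2 norms, i.e. the global L2 norm over every gradient element in the group. A minimal standalone sketch of the same computation, assuming a plain nn.Module (group_grad_norm and the toy model below are illustrative, not part of the repo):

import torch
import torch.nn as nn

def group_grad_norm(params):
    # Norm of the per-parameter gradient norms; torch.norm with an
    # integer p and no dim flattens each tensor, so this is the global
    # L2 norm over all gradient elements in the group.
    return torch.norm(torch.stack([torch.norm(p.grad.detach(), 2) for p in params]), 2)

model = nn.Linear(4, 2)
model(torch.randn(8, 4)).sum().backward()
print(group_grad_norm([p for p in model.parameters() if p.grad is not None]))

Networks that expose get_grad_norm_parameter_groups() get one such norm per named sub-group; anything else falls back to a single '<name>_all_parameters' group.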
@@ -360,11 +365,6 @@ class ExtensibleTrainer(BaseModel):
 
     def consume_gradients(self, state, step, it):
         [e.before_optimize(state) for e in self.experiments]
-        # Call into pre-step hooks.
-        for name, net in self.networks.items():
-            if hasattr(net.module, "before_step"):
-                net.module.before_step(it)
-
         step.do_step(it)
 
         # Call into custom step hooks as well as update EMA params.
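The motivation is visible in the second hunk: consume_gradients calls step.do_step(it) immediately after the hooks, so any gradient adjustment a network made in before_step was applied after the grad norms had already been gathered. Hoisting the hooks into the training loop ahead of the norm computation makes those adjustments show up in the visualized norms. A hypothetical network sketching the hook interface (not from the repo; the trainer reaches through the DDP-style wrapper via net.module):

import torch.nn as nn

class ExampleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(4, 4)

    def before_step(self, it):
        # Runs once per optimizer step, now *before* grad norms are
        # gathered and before step.do_step(it), so this edit is
        # reflected in the logged norms.
        for p in self.parameters():
            if p.grad is not None:
                p.grad.mul_(0.5)  # illustrative rescaling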