diff --git a/codes/trainer/ExtensibleTrainer.py b/codes/trainer/ExtensibleTrainer.py
index 72ff97a7..a3db905d 100644
--- a/codes/trainer/ExtensibleTrainer.py
+++ b/codes/trainer/ExtensibleTrainer.py
@@ -303,19 +303,24 @@ class ExtensibleTrainer(BaseModel):
                     raise OverwrittenStateError(k, list(state.keys()))
                 state[k] = v
 
-            if return_grad_norms and train_step:
-                for name in nets_to_train:
-                    model = self.networks[name]
-                    if hasattr(model.module, 'get_grad_norm_parameter_groups'):
-                        pgroups = {f'{name}_{k}': v for k, v in model.module.get_grad_norm_parameter_groups().items()}
-                    else:
-                        pgroups = {f'{name}_all_parameters': list(model.parameters())}
-                    for name in pgroups.keys():
-                        grad_norms[name] = torch.norm(torch.stack([torch.norm(p.grad.detach(), 2) for p in pgroups[name]]), 2)
-
             # (Maybe) perform a step.
             if train_step and optimize and self.batch_size_optimizer.should_step(it):
+                # Call into pre-step hooks.
+                for name, net in self.networks.items():
+                    if hasattr(net.module, "before_step"):
+                        net.module.before_step(it)
+
+                if return_grad_norms and train_step:
+                    for name in nets_to_train:
+                        model = self.networks[name]
+                        if hasattr(model.module, 'get_grad_norm_parameter_groups'):
+                            pgroups = {f'{name}_{k}': v for k, v in model.module.get_grad_norm_parameter_groups().items()}
+                        else:
+                            pgroups = {f'{name}_all_parameters': list(model.parameters())}
+                        for name in pgroups.keys():
+                            grad_norms[name] = torch.norm(torch.stack([torch.norm(p.grad.detach(), 2) for p in pgroups[name]]), 2)
+
                 self.consume_gradients(state, step, it)
@@ -360,11 +365,6 @@ class ExtensibleTrainer(BaseModel):
     def consume_gradients(self, state, step, it):
         [e.before_optimize(state) for e in self.experiments]
 
-        # Call into pre-step hooks.
-        for name, net in self.networks.items():
-            if hasattr(net.module, "before_step"):
-                net.module.before_step(it)
-
         step.do_step(it)
 
         # Call into custom step hooks as well as update EMA params.
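
Both hooks the relocated code probes for via hasattr are optional on the wrapped module. A minimal sketch of a module that opts into them (the class name, layer sizes, and grouping below are illustrative, not taken from the repo):

```python
import torch
import torch.nn as nn


class ToyNet(nn.Module):
    """Hypothetical module showing the optional hooks ExtensibleTrainer looks for."""

    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(16, 32)
        self.decoder = nn.Linear(32, 16)

    def forward(self, x):
        return self.decoder(self.encoder(x))

    def before_step(self, it):
        # Invoked right before step.do_step(it); with this patch it also runs
        # before gradient norms are collected, so any gradient/weight fix-ups
        # done here are reflected in the logged norms.
        pass

    def get_grad_norm_parameter_groups(self):
        # Optional: lets the trainer log one gradient norm per named group
        # ('<netname>_encoder', '<netname>_decoder') instead of a single
        # '<netname>_all_parameters' entry.
        return {
            'encoder': list(self.encoder.parameters()),
            'decoder': list(self.decoder.parameters()),
        }
```

As the diff reads, the grad-norm measurement now happens only on iterations where batch_size_optimizer.should_step(it) is true (rather than on every accumulation pass), and the before_step hooks, previously called inside consume_gradients, now fire first in that same branch, before the norms are read.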