diff --git a/codes/models/lightweight_gan.py b/codes/models/lightweight_gan.py index 340a3c63..8a069fc8 100644 --- a/codes/models/lightweight_gan.py +++ b/codes/models/lightweight_gan.py @@ -259,7 +259,7 @@ class EMAWrapper(nn.Module): new_buffer_value = self.ema_updater.update_average(ma_buffer, current_buffer) ma_buffer.copy_(new_buffer_value) - def custom_optimizer_step(self, step): + def after_step(self, step): if step % self.steps_per_ema == 0: self.update_moving_average() if step % self.steps_per_reset and step < self.steps_after_no_reset: diff --git a/codes/trainer/ExtensibleTrainer.py b/codes/trainer/ExtensibleTrainer.py index 3ef20e07..72ff97a7 100644 --- a/codes/trainer/ExtensibleTrainer.py +++ b/codes/trainer/ExtensibleTrainer.py @@ -360,12 +360,17 @@ class ExtensibleTrainer(BaseModel): def consume_gradients(self, state, step, it): [e.before_optimize(state) for e in self.experiments] + # Call into pre-step hooks. + for name, net in self.networks.items(): + if hasattr(net.module, "before_step"): + net.module.before_step(it) + step.do_step(it) # Call into custom step hooks as well as update EMA params. for name, net in self.networks.items(): - if hasattr(net, "custom_optimizer_step"): - net.custom_optimizer_step(it) + if hasattr(net.module, "after_step"): + net.module.after_step(it) if self.do_emas: ema_params = self.emas[name].parameters() net_params = net.parameters() diff --git a/codes/trainer/steps.py b/codes/trainer/steps.py index c2fa5602..c981b87c 100644 --- a/codes/trainer/steps.py +++ b/codes/trainer/steps.py @@ -27,7 +27,6 @@ class ConfigurableStep(Module): self.optimizers = None self.scaler = GradScaler(enabled=self.opt['fp16'] or opt_get(self.opt, ['grad_scaler_enabled'], False)) self.grads_generated = False - self.min_total_loss = opt_step['min_total_loss'] if 'min_total_loss' in opt_step.keys() else -999999999 self.clip_grad_eps = opt_get(opt_step, ['clip_grad_eps'], None) # This is a half-measure that can be used between anomaly_detection and running a potentially problematic @@ -267,22 +266,12 @@ class ConfigurableStep(Module): # In some cases, the loss could not be set (e.g. all losses have 'after') if train and isinstance(total_loss, torch.Tensor) and total_loss.isfinite(): loss_accumulator.add_loss("%s_total" % (self.get_training_network_name(),), total_loss) - reset_required = total_loss < self.min_total_loss # Scale the loss down by the accumulation factor. total_loss = total_loss / self.env['mega_batch_factor'] # Get dem grads! self.scaler.scale(total_loss).backward() - - if reset_required: - # You might be scratching your head at this. Why would you zero grad as opposed to not doing a - # backwards? Because DDP uses the backward() pass as a synchronization point and there is not a good - # way to simply bypass backward. If you want a more efficient way to specify a min_loss, use or - # implement it at the loss level. - self.get_network_for_name(self.step_opt['training']).zero_grad() - loss_accumulator.increment_metric("%s_skipped_steps" % (self.get_training_network_name(),)) - self.grads_generated = True # Reset nan_loss_counter self.nan_loss_counter = 0