diff --git a/codes/models/steps/steps.py b/codes/models/steps/steps.py
index 924325cd..5add1290 100644
--- a/codes/models/steps/steps.py
+++ b/codes/models/steps/steps.py
@@ -155,12 +155,13 @@ class ConfigurableStep(Module):
                 self.loss_accumulator.add_loss(loss_name, l)
             for n, v in loss.extra_metrics():
                 self.loss_accumulator.add_loss("%s_%s" % (loss_name, n), v)
-        self.loss_accumulator.add_loss("%s_total" % (self.get_training_network_name(),), total_loss)
-        # Scale the loss down by the accumulation factor.
-        total_loss = total_loss / self.env['mega_batch_factor']
-
+        # In some cases, the loss could not be set (e.g. all losses have 'after')
         if isinstance(loss, torch.Tensor):
+            self.loss_accumulator.add_loss("%s_total" % (self.get_training_network_name(),), total_loss)
+            # Scale the loss down by the accumulation factor.
+            total_loss = total_loss / self.env['mega_batch_factor']
+
             # Get dem grads!
             if self.env['amp']:
                 with amp.scale_loss(total_loss, self.optimizers, amp_loss_id) as scaled_loss:
                     scaled_loss.backward()
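
Note on the pattern above. total_loss starts out as the integer 0 and only becomes a
torch.Tensor once at least one loss term is actually added; if every configured loss is
deferred (e.g. gated behind an 'after' step threshold), it stays a plain int, and scaling
it or backpropagating through it would be wrong. The patch therefore moves the total-loss
bookkeeping and the accumulation-factor scaling inside the existing isinstance guard.
What follows is a minimal standalone sketch of that guard, not code from the repository:
the names backward_pass, weighted_losses, mega_batch_factor, and use_amp are hypothetical,
and for simplicity the isinstance check here is applied to total_loss itself.

    import torch

    def backward_pass(weighted_losses, optimizers, mega_batch_factor=2,
                      use_amp=False, amp_loss_id=0):
        total_loss = 0  # stays a plain int if every loss below is skipped
        for weight, value in weighted_losses:
            total_loss = total_loss + weight * value

        # The fix from the patch: only scale and backprop when a real
        # loss tensor was actually produced.
        if isinstance(total_loss, torch.Tensor):
            # Divide so gradients accumulated over `mega_batch_factor`
            # micro-batches sum to one full-batch gradient.
            total_loss = total_loss / mega_batch_factor
            if use_amp:
                # Same apex API the patch uses; loss_id distinguishes
                # independently scaled losses.
                from apex import amp
                with amp.scale_loss(total_loss, optimizers,
                                    loss_id=amp_loss_id) as scaled_loss:
                    scaled_loss.backward()
            else:
                total_loss.backward()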