diff --git a/codes/models/steps/steps.py b/codes/models/steps/steps.py index 5add1290..0ed4c09f 100644 --- a/codes/models/steps/steps.py +++ b/codes/models/steps/steps.py @@ -155,9 +155,9 @@ class ConfigurableStep(Module): self.loss_accumulator.add_loss(loss_name, l) for n, v in loss.extra_metrics(): self.loss_accumulator.add_loss("%s_%s" % (loss_name, n), v) - + # In some cases, the loss could not be set (e.g. all losses have 'after' - if isinstance(loss, torch.Tensor): + if isinstance(total_loss, torch.Tensor): self.loss_accumulator.add_loss("%s_total" % (self.get_training_network_name(),), total_loss) # Scale the loss down by the accumulation factor. total_loss = total_loss / self.env['mega_batch_factor']