diff --git a/codes/models/steps/steps.py b/codes/models/steps/steps.py index 0ed4c09f..f7c9ca66 100644 --- a/codes/models/steps/steps.py +++ b/codes/models/steps/steps.py @@ -152,7 +152,8 @@ class ConfigurableStep(Module): l = loss(self.training_net, local_state) total_loss += l * self.weights[loss_name] # Record metrics. - self.loss_accumulator.add_loss(loss_name, l) + if isinstance(l, torch.Tensor): + self.loss_accumulator.add_loss(loss_name, l) for n, v in loss.extra_metrics(): self.loss_accumulator.add_loss("%s_%s" % (loss_name, n), v)