From 8197fd646fdd0d1a78df510fc6d2c41b33af32e1 Mon Sep 17 00:00:00 2001 From: James Betker Date: Sat, 3 Oct 2020 11:03:55 -0600 Subject: [PATCH] Don't accumulate losses for metrics when the loss isn't a tensor --- codes/models/steps/steps.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/codes/models/steps/steps.py b/codes/models/steps/steps.py index 0ed4c09f..f7c9ca66 100644 --- a/codes/models/steps/steps.py +++ b/codes/models/steps/steps.py @@ -152,7 +152,8 @@ class ConfigurableStep(Module): l = loss(self.training_net, local_state) total_loss += l * self.weights[loss_name] # Record metrics. - self.loss_accumulator.add_loss(loss_name, l) + if isinstance(l, torch.Tensor): + self.loss_accumulator.add_loss(loss_name, l) for n, v in loss.extra_metrics(): self.loss_accumulator.add_loss("%s_%s" % (loss_name, n), v)