Don't accumulate losses for metrics when the loss isn't a tensor

This commit is contained in:
James Betker 2020-10-03 11:03:55 -06:00
parent 19a4075e1e
commit 8197fd646f

View File

@ -152,7 +152,8 @@ class ConfigurableStep(Module):
l = loss(self.training_net, local_state)
total_loss += l * self.weights[loss_name]
# Record metrics.
self.loss_accumulator.add_loss(loss_name, l)
if isinstance(l, torch.Tensor):
self.loss_accumulator.add_loss(loss_name, l)
for n, v in loss.extra_metrics():
self.loss_accumulator.add_loss("%s_%s" % (loss_name, n), v)