Don't accumulate losses for metrics when the loss isn't a tensor
This commit is contained in:
parent
19a4075e1e
commit
8197fd646f
|
@ -152,7 +152,8 @@ class ConfigurableStep(Module):
|
|||
l = loss(self.training_net, local_state)
|
||||
total_loss += l * self.weights[loss_name]
|
||||
# Record metrics.
|
||||
self.loss_accumulator.add_loss(loss_name, l)
|
||||
if isinstance(l, torch.Tensor):
|
||||
self.loss_accumulator.add_loss(loss_name, l)
|
||||
for n, v in loss.extra_metrics():
|
||||
self.loss_accumulator.add_loss("%s_%s" % (loss_name, n), v)
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user