forked from mrq/DL-Art-School
Don't accumulate losses for metrics when the loss isn't a tensor
This commit is contained in:
parent
19a4075e1e
commit
8197fd646f
|
@ -152,6 +152,7 @@ class ConfigurableStep(Module):
|
||||||
l = loss(self.training_net, local_state)
|
l = loss(self.training_net, local_state)
|
||||||
total_loss += l * self.weights[loss_name]
|
total_loss += l * self.weights[loss_name]
|
||||||
# Record metrics.
|
# Record metrics.
|
||||||
|
if isinstance(l, torch.Tensor):
|
||||||
self.loss_accumulator.add_loss(loss_name, l)
|
self.loss_accumulator.add_loss(loss_name, l)
|
||||||
for n, v in loss.extra_metrics():
|
for n, v in loss.extra_metrics():
|
||||||
self.loss_accumulator.add_loss("%s_%s" % (loss_name, n), v)
|
self.loss_accumulator.add_loss("%s_%s" % (loss_name, n), v)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user