TOTAL_loss, dumbo

This commit is contained in:
James Betker 2020-10-02 21:06:10 -06:00
parent 4e44fcd655
commit 39865ca3df

View File

@ -155,9 +155,9 @@ class ConfigurableStep(Module):
self.loss_accumulator.add_loss(loss_name, l)
for n, v in loss.extra_metrics():
self.loss_accumulator.add_loss("%s_%s" % (loss_name, n), v)
# In some cases, the loss could not be set (e.g. all losses have 'after'
if isinstance(loss, torch.Tensor):
if isinstance(total_loss, torch.Tensor):
self.loss_accumulator.add_loss("%s_total" % (self.get_training_network_name(),), total_loss)
# Scale the loss down by the accumulation factor.
total_loss = total_loss / self.env['mega_batch_factor']