Only log the name of the first network in the total_loss training set

This commit is contained in:
James Betker 2020-09-12 16:07:09 -06:00
parent fb595e72a4
commit 5b85f891af

View File

@ -133,7 +133,7 @@ class ConfigurableStep(Module):
self.loss_accumulator.add_loss(loss_name, l)
for n, v in loss.extra_metrics():
self.loss_accumulator.add_loss("%s_%s" % (loss_name, n), v)
self.loss_accumulator.add_loss("%s_total" % (self.step_opt['training'],), total_loss)
self.loss_accumulator.add_loss("%s_total" % (self.step_opt['training'][0],), total_loss)
# Scale the loss down by the accumulation factor.
total_loss = total_loss / self.env['mega_batch_factor']