From 4e44fcd655e857edec8cdbc28a4b1d6c86683817 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Fri, 2 Oct 2020 20:55:33 -0600
Subject: [PATCH] Loss accumulator fix

Recording the total loss in the accumulator and scaling it by
mega_batch_factor previously ran unconditionally, which breaks on steps
where no loss produced a tensor (e.g. every loss is gated behind an
'after' step count, leaving total_loss as the plain integer 0). Move both
operations inside the existing isinstance guard that already protects the
backward pass.
---
 codes/models/steps/steps.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/codes/models/steps/steps.py b/codes/models/steps/steps.py
index 924325cd..5add1290 100644
--- a/codes/models/steps/steps.py
+++ b/codes/models/steps/steps.py
@@ -155,12 +155,13 @@ class ConfigurableStep(Module):
                 self.loss_accumulator.add_loss(loss_name, l)
             for n, v in loss.extra_metrics():
                 self.loss_accumulator.add_loss("%s_%s" % (loss_name, n), v)
 
-        self.loss_accumulator.add_loss("%s_total" % (self.get_training_network_name(),), total_loss)
-        # Scale the loss down by the accumulation factor.
-        total_loss = total_loss / self.env['mega_batch_factor']
-
+        # In some cases, the loss could not be set (e.g. all losses have 'after')
         if isinstance(loss, torch.Tensor):
+            self.loss_accumulator.add_loss("%s_total" % (self.get_training_network_name(),), total_loss)
+            # Scale the loss down by the accumulation factor.
+            total_loss = total_loss / self.env['mega_batch_factor']
+
             # Get dem grads!
             if self.env['amp']:
                 with amp.scale_loss(total_loss, self.optimizers, amp_loss_id) as scaled_loss:
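
Note: below is a minimal, self-contained sketch of the pattern this patch
establishes: skip the total-loss bookkeeping, scaling, and backward pass
whenever no loss actually fired. Only mega_batch_factor and the isinstance
guard mirror the patch; the toy model and single loss are hypothetical
stand-ins for the repository's configurable losses, and the sketch checks
total_loss directly where the patch checks the last per-loss value.

    import torch
    import torch.nn as nn

    model = nn.Linear(4, 1)
    opt = torch.optim.SGD(model.parameters(), lr=0.1)
    mega_batch_factor = 2  # gradient-accumulation chunks per optimizer step

    x, y = torch.randn(8, 4), torch.randn(8, 1)

    opt.zero_grad()
    for cx, cy in zip(x.chunk(mega_batch_factor), y.chunk(mega_batch_factor)):
        total_loss = 0  # stays a plain int if every loss is gated off this step
        for loss_fn in [nn.MSELoss()]:  # stand-in for losses gated by 'after'
            total_loss = total_loss + loss_fn(model(cx), cy)
        # The guard from the patch: only record/scale/backprop when a loss
        # fired; a bare 0 has no .backward() and should not be accumulated.
        if isinstance(total_loss, torch.Tensor):
            # Dividing by the accumulation factor averages the gradients
            # contributed by each chunk, matching a full-batch step.
            (total_loss / mega_batch_factor).backward()
    opt.step()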