diff --git a/codes/models/steps/steps.py b/codes/models/steps/steps.py
index 05a48f43..924325cd 100644
--- a/codes/models/steps/steps.py
+++ b/codes/models/steps/steps.py
@@ -159,12 +159,14 @@ class ConfigurableStep(Module):
         # Scale the loss down by the accumulation factor.
         total_loss = total_loss / self.env['mega_batch_factor']
 
-        # Get dem grads!
-        if self.env['amp']:
-            with amp.scale_loss(total_loss, self.optimizers, amp_loss_id) as scaled_loss:
-                scaled_loss.backward()
-        else:
-            total_loss.backward()
+        # In some cases, the loss may not have been set (e.g. all losses have 'after' and are not active yet).
+        if isinstance(loss, torch.Tensor):
+            # Get dem grads!
+            if self.env['amp']:
+                with amp.scale_loss(total_loss, self.optimizers, amp_loss_id) as scaled_loss:
+                    scaled_loss.backward()
+            else:
+                total_loss.backward()
 
         # Detach all state variables. Within the step, gradients can flow. Once these variables leave the step
         # we must release the gradients.
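
For context, below is a minimal, self-contained sketch (not the project's code) of the pattern this change introduces: losses are summed, scaled down by the gradient-accumulation factor, and backward() is only invoked when a loss was actually produced, since a step where every loss is gated off leaves the accumulator as a plain number with no autograd graph. The function name accumulate_and_backward, the loss_fns list, and the stand-in loss that returns None are hypothetical, and the Apex amp.scale_loss branch from the real step is omitted to keep the sketch dependency-free.

import torch
import torch.nn as nn

def accumulate_and_backward(net, batches, loss_fns, mega_batch_factor):
    """Run several micro-batches, calling backward() only when a loss exists."""
    for inputs, targets in batches:
        outputs = net(inputs)

        # Sum whichever losses are currently active; losses gated by an
        # 'after'-style schedule may contribute nothing at this step.
        total_loss = 0
        for fn in loss_fns:
            l = fn(outputs, targets)
            if l is not None:
                total_loss = total_loss + l

        # Scale the loss down by the accumulation factor so the accumulated
        # gradients match a single large batch.
        total_loss = total_loss / mega_batch_factor

        # If no loss fired, total_loss is still a plain Python number and has
        # no graph attached; skip backward() in that case, mirroring the
        # isinstance(..., torch.Tensor) guard added in the diff.
        if isinstance(total_loss, torch.Tensor):
            total_loss.backward()

# Hypothetical usage: two loss functions, one of which is not yet active.
net = nn.Linear(4, 2)
batches = [(torch.randn(8, 4), torch.randn(8, 2)) for _ in range(2)]
loss_fns = [nn.MSELoss(), lambda out, tgt: None]  # second loss disabled
accumulate_and_backward(net, batches, loss_fns, mega_batch_factor=2)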