ExtensibleTrainer - don't compute backward when there is no loss

James Betker 2020-10-02 20:54:06 -06:00
parent 146a9125f2
commit 567b4d50a4


@@ -159,12 +159,14 @@ class ConfigurableStep(Module):
         # Scale the loss down by the accumulation factor.
         total_loss = total_loss / self.env['mega_batch_factor']
 
-        # Get dem grads!
-        if self.env['amp']:
-            with amp.scale_loss(total_loss, self.optimizers, amp_loss_id) as scaled_loss:
-                scaled_loss.backward()
-        else:
-            total_loss.backward()
+        # In some cases, the loss could not be set (e.g. all losses have 'after'
+        if isinstance(loss, torch.Tensor):
+            # Get dem grads!
+            if self.env['amp']:
+                with amp.scale_loss(total_loss, self.optimizers, amp_loss_id) as scaled_loss:
+                    scaled_loss.backward()
+            else:
+                total_loss.backward()
 
         # Detach all state variables. Within the step, gradients can flow. Once these variables leave the step
         # we must release the gradients.
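The change boils down to guarding the backward pass behind an isinstance check, so a step that produced no loss tensor is skipped instead of raising an error. A minimal standalone sketch of that pattern in plain PyTorch (without the Apex amp branch; maybe_backward and accumulation_factor are illustrative names, not ExtensibleTrainer's API):

    import torch

    def maybe_backward(total_loss, accumulation_factor=1):
        # Only propagate gradients when a loss was actually produced this step.
        # total_loss may be None (or a plain number) when every configured loss
        # was skipped for this iteration.
        if isinstance(total_loss, torch.Tensor):
            (total_loss / accumulation_factor).backward()
            return True
        return False

    # Usage: a step with no loss becomes a no-op instead of an AttributeError.
    w = torch.randn(3, requires_grad=True)
    maybe_backward((w * 2).sum(), accumulation_factor=2)  # populates w.grad
    maybe_backward(None)                                  # skipped, no error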