forked from mrq/DL-Art-School
ExtensibleTrainer - don't compute backward when there is no loss
This commit is contained in:
parent
146a9125f2
commit
567b4d50a4
|
@ -159,12 +159,14 @@ class ConfigurableStep(Module):
|
||||||
# Scale the loss down by the accumulation factor.
|
# Scale the loss down by the accumulation factor.
|
||||||
total_loss = total_loss / self.env['mega_batch_factor']
|
total_loss = total_loss / self.env['mega_batch_factor']
|
||||||
|
|
||||||
# Get dem grads!
|
# In some cases, the loss could not be set (e.g. all losses have 'after'
|
||||||
if self.env['amp']:
|
if isinstance(loss, torch.Tensor):
|
||||||
with amp.scale_loss(total_loss, self.optimizers, amp_loss_id) as scaled_loss:
|
# Get dem grads!
|
||||||
scaled_loss.backward()
|
if self.env['amp']:
|
||||||
else:
|
with amp.scale_loss(total_loss, self.optimizers, amp_loss_id) as scaled_loss:
|
||||||
total_loss.backward()
|
scaled_loss.backward()
|
||||||
|
else:
|
||||||
|
total_loss.backward()
|
||||||
|
|
||||||
# Detach all state variables. Within the step, gradients can flow. Once these variables leave the step
|
# Detach all state variables. Within the step, gradients can flow. Once these variables leave the step
|
||||||
# we must release the gradients.
|
# we must release the gradients.
|
||||||
|
|
Loading…
Reference in New Issue
Block a user