Don't step if non-finite (NaN or Inf) losses are encountered.
This commit is contained in:
parent
ea500ad42a
commit
18dc62453f
|
@ -263,7 +263,7 @@ class ConfigurableStep(Module):
|
||||||
loss.clear_metrics()
|
loss.clear_metrics()
|
||||||
|
|
||||||
# In some cases, the loss could not be set (e.g. all losses have 'after')
|
# In some cases, the loss could not be set (e.g. all losses have 'after')
|
||||||
if train and isinstance(total_loss, torch.Tensor):
|
if train and isinstance(total_loss, torch.Tensor) and total_loss.isfinite():
|
||||||
loss_accumulator.add_loss("%s_total" % (self.get_training_network_name(),), total_loss)
|
loss_accumulator.add_loss("%s_total" % (self.get_training_network_name(),), total_loss)
|
||||||
reset_required = total_loss < self.min_total_loss
|
reset_required = total_loss < self.min_total_loss
|
||||||
|
|
||||||
|
@ -282,6 +282,8 @@ class ConfigurableStep(Module):
|
||||||
loss_accumulator.increment_metric("%s_skipped_steps" % (self.get_training_network_name(),))
|
loss_accumulator.increment_metric("%s_skipped_steps" % (self.get_training_network_name(),))
|
||||||
|
|
||||||
self.grads_generated = True
|
self.grads_generated = True
|
||||||
|
elif not total_loss.isfinite():
|
||||||
|
print("Non-finite loss encountered. Skipping backwards step.")
|
||||||
|
|
||||||
# Detach all state variables. Within the step, gradients can flow. Once these variables leave the step
|
# Detach all state variables. Within the step, gradients can flow. Once these variables leave the step
|
||||||
# we must release the gradients.
|
# we must release the gradients.
|
||||||
|
|
Loading…
Reference in New Issue
Block a user