diff --git a/codes/trainer/steps.py b/codes/trainer/steps.py
index 15e99aa1..6f57e4be 100644
--- a/codes/trainer/steps.py
+++ b/codes/trainer/steps.py
@@ -263,7 +263,7 @@ class ConfigurableStep(Module):
                 loss.clear_metrics()
 
         # In some cases, the loss could not be set (e.g. all losses have 'after')
-        if train and isinstance(total_loss, torch.Tensor):
+        if train and isinstance(total_loss, torch.Tensor) and total_loss.isfinite():
             loss_accumulator.add_loss("%s_total" % (self.get_training_network_name(),), total_loss)
             reset_required = total_loss < self.min_total_loss
 
@@ -282,6 +282,8 @@ class ConfigurableStep(Module):
                     loss_accumulator.increment_metric("%s_skipped_steps" % (self.get_training_network_name(),))
             self.grads_generated = True
+        elif isinstance(total_loss, torch.Tensor) and not total_loss.isfinite():
+            print("Non-finite loss encountered. Skipping backwards step.")
 
         # Detach all state variables. Within the step, gradients can flow. Once these variables leave the step
         # we must release the gradients.
 
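
For context, a minimal standalone sketch of the guard this diff introduces: skip the backward pass (and report it) whenever the total loss is NaN or infinite. This is illustrative only; the names run_step, model, opt, and loss_fn are assumptions for the sketch, not the project's actual API.

    import torch

    def run_step(model, opt, loss_fn, inputs, targets, train=True):
        # Hypothetical generic training step mirroring the diff's non-finite-loss guard.
        opt.zero_grad()
        total_loss = loss_fn(model(inputs), targets)
        if train and isinstance(total_loss, torch.Tensor) and total_loss.isfinite():
            # Only finite losses are allowed to produce gradients and update weights.
            total_loss.backward()
            opt.step()
        elif isinstance(total_loss, torch.Tensor) and not total_loss.isfinite():
            print("Non-finite loss encountered. Skipping backwards step.")
        return total_loss

The isinstance check on the elif branch matters: if no loss was ever accumulated (e.g. all losses are deferred with 'after'), total_loss may still be a plain number, and calling .isfinite() on it would raise an AttributeError.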