diff --git a/codes/trainer/steps.py b/codes/trainer/steps.py index 6f57e4be..c2fa5602 100644 --- a/codes/trainer/steps.py +++ b/codes/trainer/steps.py @@ -36,6 +36,8 @@ class ConfigurableStep(Module): # noticeable affect on training speed, but nowhere near as bad as anomaly_detection. self.check_grads_for_nan = opt_get(opt_step, ['check_grads_for_nan'], False) self.nan_counter = 0 + # Counts consecutive non-finite losses; a similar mechanism plugged into the forward() pass. It cannot be turned off. + self.nan_loss_counter = 0 self.injectors = [] if 'injectors' in self.step_opt.keys(): @@ -282,8 +284,16 @@ class ConfigurableStep(Module): loss_accumulator.increment_metric("%s_skipped_steps" % (self.get_training_network_name(),)) self.grads_generated = True + # Reset the consecutive non-finite loss counter. + self.nan_loss_counter = 0 elif not total_loss.isfinite(): print("Non-finite loss encountered. Skipping backwards step.") + self.nan_loss_counter += 1 + if self.nan_loss_counter >= 10: + print("Encountered 10 NaN losses in a row. Something is screwed up. Dumping model weights and exiting.") + if self.env['rank'] == 0: + torch.save(training_net.state_dict(), "nan_error_weights.pth") + raise SystemExit(1) # Detach all state variables. Within the step, gradients can flow. Once these variables leave the step # we must release the gradients.