abort early if losses go NaN too often, and save the model

James Betker 2022-02-24 20:55:30 -07:00
parent 18dc62453f
commit f458f5d8f1


@@ -36,6 +36,8 @@ class ConfigurableStep(Module):
# noticeable effect on training speed, but nowhere near as bad as anomaly_detection.
self.check_grads_for_nan = opt_get(opt_step, ['check_grads_for_nan'], False)
self.nan_counter = 0
# This is a similar mechanism plugged into the forward() pass. It cannot be turned off.
self.nan_loss_counter = 0
self.injectors = []
if 'injectors' in self.step_opt.keys():
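
The check_grads_for_nan flag above guards the backward pass by inspecting gradients before the optimizer step. The diff does not show that check itself, so the following is only a minimal sketch of what such a gradient scan typically looks like; training_net follows the naming in this commit, the helper name is assumed:

import torch

def grads_are_finite(training_net: torch.nn.Module) -> bool:
    # Scan every parameter gradient produced by backward(); any NaN/Inf
    # means the optimizer step should be skipped for this iteration.
    for p in training_net.parameters():
        if p.grad is not None and not torch.isfinite(p.grad).all():
            return False
    return True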
@@ -282,8 +284,16 @@
loss_accumulator.increment_metric("%s_skipped_steps" % (self.get_training_network_name(),))
self.grads_generated = True
# Reset nan_loss_counter
self.nan_loss_counter = 0
elif not total_loss.isfinite():
print("Non-finite loss encountered. Skipping backwards step.")
self.nan_loss_counter += 1
if self.nan_loss_counter > 10:
print("Encountered 10 NaN losses in a row. Something is screwed up. Dumping model weights and exiting.")
if self.env['rank'] == 0:
torch.save(training_net.state_dict(), "nan_error_weights.pth")
exit(1)
# Detach all state variables. Within the step, gradients can flow. Once these variables leave the step
# we must release the gradients.
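
Taken together, the forward-pass guard added in this commit amounts to the following pattern; this is a minimal standalone sketch, where the 10-loss threshold, the rank-0 check, and the "nan_error_weights.pth" filename come from the diff and the function name and signature are assumed:

import torch

NAN_LOSS_LIMIT = 10  # consecutive non-finite losses tolerated, per the diff

def guard_loss(total_loss, training_net, nan_loss_counter, rank=0):
    # Returns the updated consecutive-NaN counter; dumps weights and aborts
    # the run once the loss has been non-finite NAN_LOSS_LIMIT times in a row.
    if total_loss.isfinite():
        return 0  # healthy loss: reset the counter
    print("Non-finite loss encountered. Skipping backwards step.")
    nan_loss_counter += 1
    if nan_loss_counter > NAN_LOSS_LIMIT:
        # Save from one process only so distributed workers don't clobber the file.
        if rank == 0:
            torch.save(training_net.state_dict(), "nan_error_weights.pth")
        exit(1)
    return nan_loss_counter

Resetting the counter on any finite loss means the abort only fires on a sustained run of NaNs, so a single transient spike does not kill training.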