abort early if losses reach NaN too often, and save the model
This commit is contained in:
parent 18dc62453f
commit f458f5d8f1
@@ -36,6 +36,8 @@ class ConfigurableStep(Module):
         # noticeable affect on training speed, but nowhere near as bad as anomaly_detection.
         self.check_grads_for_nan = opt_get(opt_step, ['check_grads_for_nan'], False)
         self.nan_counter = 0
+        # This is a similar mechanism plugged into the forward() pass. It cannot be turned off.
+        self.nan_loss_counter = 0

         self.injectors = []
         if 'injectors' in self.step_opt.keys():
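For context, `opt_get(opt_step, ['check_grads_for_nan'], False)` reads an optional key out of a nested config dict, falling back to a default when the key is absent. A minimal sketch of such a helper, assuming a plain-dict config (an illustration, not necessarily the repo's exact implementation):

```python
def opt_get(opt, keys, default=None):
    # Walk the nested dict along `keys`; bail out to `default`
    # as soon as any level is missing.
    for key in keys:
        if not isinstance(opt, dict) or key not in opt:
            return default
        opt = opt[key]
    return opt

# Usage mirroring the hunk above: read an optional boolean step flag.
opt_step = {'check_grads_for_nan': True}
check_grads_for_nan = opt_get(opt_step, ['check_grads_for_nan'], False)  # -> True
```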
@@ -282,8 +284,16 @@ class ConfigurableStep(Module):
                     loss_accumulator.increment_metric("%s_skipped_steps" % (self.get_training_network_name(),))

             self.grads_generated = True
+            # Reset nan_loss_counter
+            self.nan_loss_counter = 0
+        elif not total_loss.isfinite():
+            print("Non-finite loss encountered. Skipping backwards step.")
+            self.nan_loss_counter += 1
+            if self.nan_loss_counter > 10:
+                print("Encountered 10 NaN losses in a row. Something is screwed up. Dumping model weights and exiting.")
+                if self.env['rank'] == 0:
+                    torch.save(training_net.state_dict(), "nan_error_weights.pth")
+                exit(1)

         # Detach all state variables. Within the step, gradients can flow. Once these variables leave the step
         # we must release the gradients.
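Taken together, the hunks implement a simple watchdog: a counter that increments on every non-finite loss, resets on any healthy step, and aborts (after dumping weights for post-mortem debugging) once it crosses a threshold. A minimal standalone sketch of the same pattern, where `model`, `optimizer`, `batches`, `compute_loss`, and `rank` are hypothetical stand-ins for the step's state:

```python
import torch

def train(model, optimizer, batches, compute_loss, rank=0, max_nan_steps=10):
    nan_loss_counter = 0
    for batch in batches:
        optimizer.zero_grad()
        total_loss = compute_loss(model, batch)
        if total_loss.isfinite():
            total_loss.backward()
            optimizer.step()
            nan_loss_counter = 0  # any healthy step resets the watchdog
        else:
            # A single non-finite loss can be transient, so only skip this step...
            print("Non-finite loss encountered. Skipping backwards step.")
            nan_loss_counter += 1
            # ...but a long unbroken run of them means training is unrecoverable.
            if nan_loss_counter > max_nan_steps:
                if rank == 0:  # only one distributed worker should write the dump
                    torch.save(model.state_dict(), "nan_error_weights.pth")
                raise SystemExit(1)
```

Saving the weights before exiting preserves the broken state for offline inspection, and gating the save on rank 0 keeps distributed workers from clobbering each other's dump file.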