Build in capacity to revert & resume networks that encounter a NaN

I'm increasingly seeing issues where something like this can be useful. In many (most?)
cases it's just a waste of compute, though. Still, better than a cold computer for a whole
night.
This commit is contained in:
James Betker 2021-11-01 16:14:59 -06:00
parent 87364b890f
commit ee9b199d2b
2 changed files with 27 additions and 4 deletions

View File

@ -46,6 +46,7 @@ class ExtensibleTrainer(BaseModel):
self.batch_factor = self.mega_batch_factor
self.ema_rate = opt_get(train_opt, ['ema_rate'], .999)
self.checkpointing_cache = opt['checkpointing_enabled']
self.auto_recover = opt_get(opt, ['automatically_recover_nan_by_reverting_n_saves'], None)
self.netsG = {}
self.netsD = {}
@ -260,10 +261,23 @@ class ExtensibleTrainer(BaseModel):
s.do_step(step)
if s.nan_counter > 10:
print("Detected NaN grads more than 10 steps in a row. Saving model weights and aborting.")
self.save(step)
self.save_training_state(0, step)
raise ArithmeticError
if self.auto_recover is None:
print("Detected NaN grads more than 10 steps in a row. Saving model weights and aborting.")
self.save(step)
self.save_training_state(0, step)
raise ArithmeticError
else:
print(f"!!!!!!!!Detected NaN grads more than 10 steps in a row. Restoring to a state {self.auto_recover} saves ago.")
for k, ps in self.save_history.keys():
if len(ps) < self.auto_recover:
print("Belay that - not enough saves were recorded. Failing instead.")
raise ArithmeticError
if k == '__state__':
self.resume_training(torch.load(ps[-self.auto_recover]))
else:
if k in self.networks.keys(): # This isn't always the case, for example for EMAs.
self.load_network(ps[-self.auto_recover], self.networks[k], strict=True)
self.load_network(self.save_history[f'{k}_ema'][-self.auto_recover], self.emas[k], strict=True)
# Call into custom step hooks as well as update EMA params.
for name, net in self.networks.items():

View File

@ -20,6 +20,7 @@ class BaseModel():
self.schedulers = []
self.optimizers = []
self.disc_optimizers = []
self.save_history = {}
def feed_data(self, data):
pass
@ -89,6 +90,10 @@ class BaseModel():
for key, param in state_dict.items():
state_dict[key] = param.cpu()
torch.save(state_dict, save_path)
if network_label not in self.save_history.keys():
self.save_history[network_label] = []
self.save_history[network_label].append(save_path)
# Also save to the 'alt_path' which is useful for caching to Google Drive in colab, for example.
if 'alt_path' in self.opt['path'].keys():
torch.save(state_dict, os.path.join(self.opt['path']['alt_path'], save_filename))
@ -134,6 +139,10 @@ class BaseModel():
save_filename = '{}.state'.format(iter_step)
save_path = os.path.join(self.opt['path']['training_state'], save_filename)
torch.save(state, save_path)
if '__state__' not in self.save_history.keys():
self.save_history['__state__'] = []
self.save_history['__state__'].append(save_path)
# Also save to the 'alt_path' which is useful for caching to Google Drive in colab, for example.
if 'alt_path' in self.opt['path'].keys():
torch.save(state, os.path.join(self.opt['path']['alt_path'], 'latest.state'))