Build in capacity to revert & resume networks that encounter a NaN
I'm increasingly running into failures where something like this would be useful. In many (most?) cases reverting is just a waste of compute, but it still beats a machine sitting cold for a whole night after a NaN kills the run.
parent 87364b890f
commit ee9b199d2b
@@ -46,6 +46,7 @@ class ExtensibleTrainer(BaseModel):
         self.batch_factor = self.mega_batch_factor
         self.ema_rate = opt_get(train_opt, ['ema_rate'], .999)
         self.checkpointing_cache = opt['checkpointing_enabled']
+        self.auto_recover = opt_get(opt, ['automatically_recover_nan_by_reverting_n_saves'], None)
 
         self.netsG = {}
         self.netsD = {}
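The new `automatically_recover_nan_by_reverting_n_saves` option is read through `opt_get`, whose implementation isn't part of this diff. A minimal sketch of the behavior it implies — a nested-dict lookup with a default — along with a hypothetical options dict (the config shape shown is an assumption, not the repo's actual config):

```python
# Sketch of an opt_get-style helper (assumed behavior, not the repo's code):
# walk a nested options dict along a key path, returning a default when any
# key along the path is missing.
def opt_get(opt, keys, default=None):
    for k in keys:
        if not isinstance(opt, dict) or k not in opt:
            return default
        opt = opt[k]
    return opt

# Hypothetical options dict; only the new key matters here.
opt = {'automatically_recover_nan_by_reverting_n_saves': 2}
print(opt_get(opt, ['automatically_recover_nan_by_reverting_n_saves'], None))  # 2
print(opt_get({}, ['automatically_recover_nan_by_reverting_n_saves'], None))   # None -> old abort behavior
```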
@@ -260,10 +261,23 @@ class ExtensibleTrainer(BaseModel):
             s.do_step(step)
 
             if s.nan_counter > 10:
-                print("Detected NaN grads more than 10 steps in a row. Saving model weights and aborting.")
-                self.save(step)
-                self.save_training_state(0, step)
-                raise ArithmeticError
+                if self.auto_recover is None:
+                    print("Detected NaN grads more than 10 steps in a row. Saving model weights and aborting.")
+                    self.save(step)
+                    self.save_training_state(0, step)
+                    raise ArithmeticError
+                else:
+                    print(f"!!!!!!!!Detected NaN grads more than 10 steps in a row. Restoring to a state {self.auto_recover} saves ago.")
+                    for k, ps in self.save_history.items():
+                        if len(ps) < self.auto_recover:
+                            print("Belay that - not enough saves were recorded. Failing instead.")
+                            raise ArithmeticError
+                        if k == '__state__':
+                            self.resume_training(torch.load(ps[-self.auto_recover]))
+                        else:
+                            if k in self.networks.keys():  # This isn't always the case, for example for EMAs.
+                                self.load_network(ps[-self.auto_recover], self.networks[k], strict=True)
+                                self.load_network(self.save_history[f'{k}_ema'][-self.auto_recover], self.emas[k], strict=True)
 
         # Call into custom step hooks as well as update EMA params.
         for name, net in self.networks.items():
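`s.nan_counter` belongs to the step objects and isn't shown in this diff. A minimal sketch of the bookkeeping it implies — count consecutive steps whose gradients contain NaN, resetting on any clean step; the class and method names here are invented for illustration:

```python
import torch

class StepNaNTracker:
    """Toy stand-in for a training step that tracks consecutive NaN grads."""
    def __init__(self, params):
        self.params = list(params)
        self.nan_counter = 0

    def do_step(self):
        # A step is "bad" if any parameter gradient contains a NaN.
        bad = any(p.grad is not None and torch.isnan(p.grad).any()
                  for p in self.params)
        # Count consecutive bad steps; one clean step resets the streak,
        # so the > 10 check above only fires on a sustained blow-up.
        self.nan_counter = self.nan_counter + 1 if bad else 0

# Usage sketch:
p = torch.nn.Parameter(torch.ones(3))
p.grad = torch.tensor([1.0, float('nan'), 3.0])
tracker = StepNaNTracker([p])
tracker.do_step()
print(tracker.nan_counter)  # 1
```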
@@ -20,6 +20,7 @@ class BaseModel():
         self.schedulers = []
         self.optimizers = []
         self.disc_optimizers = []
+        self.save_history = {}
 
     def feed_data(self, data):
         pass
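`save_history` is just a mapping from a save label to the ordered list of checkpoint paths written under that label, with the reserved key `'__state__'` for trainer state files. Illustrating the shape it accumulates (the paths are invented examples):

```python
# Example of what save_history holds after two saves (paths are invented):
save_history = {
    'generator':     ['5000_generator.pth', '10000_generator.pth'],
    'generator_ema': ['5000_generator_ema.pth', '10000_generator_ema.pth'],
    '__state__':     ['5000.state', '10000.state'],
}

# "n saves ago" is then simply a negative index into each list:
n = 2
assert save_history['__state__'][-n] == '5000.state'
```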
@@ -89,6 +90,10 @@ class BaseModel():
         for key, param in state_dict.items():
             state_dict[key] = param.cpu()
         torch.save(state_dict, save_path)
+        if network_label not in self.save_history.keys():
+            self.save_history[network_label] = []
+        self.save_history[network_label].append(save_path)
+
         # Also save to the 'alt_path' which is useful for caching to Google Drive in colab, for example.
         if 'alt_path' in self.opt['path'].keys():
             torch.save(state_dict, os.path.join(self.opt['path']['alt_path'], save_filename))
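As an aside, the if-not-in/append pattern above can be written with `setdefault`, and since these lists grow without bound over a long run, a bounded `deque` is one alternative when only the last few saves will ever be revisited. A sketch of that variation, not what the commit does:

```python
from collections import deque

MAX_HISTORY = 10  # assumption: we never revert more than 10 saves back
save_history = {}

def record_save(label, save_path):
    # setdefault collapses the "if label not in keys" dance into one call;
    # deque(maxlen=...) silently drops the oldest path once it is full.
    save_history.setdefault(label, deque(maxlen=MAX_HISTORY)).append(save_path)

record_save('generator', '5000_generator.pth')
record_save('generator', '10000_generator.pth')
print(save_history['generator'][-2])  # 5000_generator.pth
```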
@@ -134,6 +139,10 @@ class BaseModel():
         save_filename = '{}.state'.format(iter_step)
         save_path = os.path.join(self.opt['path']['training_state'], save_filename)
         torch.save(state, save_path)
+        if '__state__' not in self.save_history.keys():
+            self.save_history['__state__'] = []
+        self.save_history['__state__'].append(save_path)
+
         # Also save to the 'alt_path' which is useful for caching to Google Drive in colab, for example.
         if 'alt_path' in self.opt['path'].keys():
             torch.save(state, os.path.join(self.opt['path']['alt_path'], 'latest.state'))
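Putting the two halves together, the pattern is: every save appends to the history, and a sustained NaN streak rewinds the run to an earlier entry instead of aborting. A self-contained toy of that control flow (all names, thresholds, and the failure model are illustrative, not the trainer's API):

```python
SAVE_EVERY = 100   # snapshot the run every N steps
AUTO_RECOVER = 2   # rewind two saves back, mirroring the new option
NAN_LIMIT = 10     # consecutive bad steps tolerated before acting

history = []                      # stands in for save_history['__state__']
state, nan_streak = {'step': 0}, 0

def train_step(state):
    state['step'] += 1
    # Toy failure model: gradients go NaN on every step past 450, which a
    # rewind temporarily escapes (real divergence is less tidy than this).
    return state['step'] > 450

for _ in range(2000):
    nan_streak = nan_streak + 1 if train_step(state) else 0
    if state['step'] % SAVE_EVERY == 0:
        history.append(dict(state))            # "save": snapshot the state
    if nan_streak > NAN_LIMIT:
        if len(history) < AUTO_RECOVER:
            raise ArithmeticError("not enough saves recorded to revert")
        print(f"NaN streak at step {state['step']}; "
              f"rewinding to step {history[-AUTO_RECOVER]['step']}")
        state = dict(history[-AUTO_RECOVER])   # "load": rewind N saves
        nan_streak = 0                         # a real trainer would also
                                               # reload optimizer/EMA state
```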