diff --git a/codes/trainer/ExtensibleTrainer.py b/codes/trainer/ExtensibleTrainer.py
index edabeab0..1ff67c83 100644
--- a/codes/trainer/ExtensibleTrainer.py
+++ b/codes/trainer/ExtensibleTrainer.py
@@ -46,6 +46,7 @@ class ExtensibleTrainer(BaseModel):
         self.batch_factor = self.mega_batch_factor
         self.ema_rate = opt_get(train_opt, ['ema_rate'], .999)
         self.checkpointing_cache = opt['checkpointing_enabled']
+        self.auto_recover = opt_get(opt, ['automatically_recover_nan_by_reverting_n_saves'], None)
 
         self.netsG = {}
         self.netsD = {}
@@ -260,10 +261,23 @@ class ExtensibleTrainer(BaseModel):
                 s.do_step(step)
 
                 if s.nan_counter > 10:
-                    print("Detected NaN grads more than 10 steps in a row. Saving model weights and aborting.")
-                    self.save(step)
-                    self.save_training_state(0, step)
-                    raise ArithmeticError
+                    if self.auto_recover is None:
+                        print("Detected NaN grads more than 10 steps in a row. Saving model weights and aborting.")
+                        self.save(step)
+                        self.save_training_state(0, step)
+                        raise ArithmeticError
+                    else:
+                        print(f"!!!!!!!!Detected NaN grads more than 10 steps in a row. Restoring to a state {self.auto_recover} saves ago.")
+                        for k, ps in self.save_history.items():
+                            if len(ps) < self.auto_recover:
+                                print("Belay that - not enough saves were recorded. Failing instead.")
+                                raise ArithmeticError
+                            if k == '__state__':
+                                self.resume_training(torch.load(ps[-self.auto_recover]))
+                            else:
+                                if k in self.networks.keys():  # This isn't always the case, for example for EMAs.
+                                    self.load_network(ps[-self.auto_recover], self.networks[k], strict=True)
+                                    self.load_network(self.save_history[f'{k}_ema'][-self.auto_recover], self.emas[k], strict=True)
 
         # Call into custom step hooks as well as update EMA params.
         for name, net in self.networks.items():
diff --git a/codes/trainer/base_model.py b/codes/trainer/base_model.py
index ab655b33..e66edcb8 100644
--- a/codes/trainer/base_model.py
+++ b/codes/trainer/base_model.py
@@ -20,6 +20,7 @@ class BaseModel():
         self.schedulers = []
         self.optimizers = []
         self.disc_optimizers = []
+        self.save_history = {}
 
     def feed_data(self, data):
         pass
@@ -89,6 +90,10 @@ class BaseModel():
         for key, param in state_dict.items():
             state_dict[key] = param.cpu()
         torch.save(state_dict, save_path)
+        if network_label not in self.save_history.keys():
+            self.save_history[network_label] = []
+        self.save_history[network_label].append(save_path)
+
         # Also save to the 'alt_path' which is useful for caching to Google Drive in colab, for example.
         if 'alt_path' in self.opt['path'].keys():
             torch.save(state_dict, os.path.join(self.opt['path']['alt_path'], save_filename))
@@ -134,6 +139,10 @@ class BaseModel():
         save_filename = '{}.state'.format(iter_step)
         save_path = os.path.join(self.opt['path']['training_state'], save_filename)
         torch.save(state, save_path)
+        if '__state__' not in self.save_history.keys():
+            self.save_history['__state__'] = []
+        self.save_history['__state__'].append(save_path)
+
         # Also save to the 'alt_path' which is useful for caching to Google Drive in colab, for example.
         if 'alt_path' in self.opt['path'].keys():
             torch.save(state, os.path.join(self.opt['path']['alt_path'], 'latest.state'))
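
For reference, below is a minimal standalone sketch of the rollback bookkeeping this patch introduces. Only save_history, auto_recover, the '__state__' label, and the opt key automatically_recover_nan_by_reverting_n_saves come from the patch; the Trainer class, record_save(), recover_from_nan(), and the example paths are hypothetical stand-ins for ExtensibleTrainer/BaseModel and the real resume_training()/load_network() calls.

class Trainer:
    def __init__(self, auto_recover=None):
        # Maps a network label (or '__state__' for optimizer/scheduler state)
        # to the ordered list of checkpoint paths written so far.
        self.save_history = {}
        # Mirrors opt['automatically_recover_nan_by_reverting_n_saves'].
        self.auto_recover = auto_recover

    def record_save(self, label, path):
        # Called after every torch.save(), as in BaseModel.save_network()
        # and BaseModel.save_training_state().
        self.save_history.setdefault(label, []).append(path)

    def recover_from_nan(self):
        # Revert every tracked label to the checkpoint written auto_recover
        # saves ago, mirroring the NaN branch added to ExtensibleTrainer.
        for label, paths in self.save_history.items():
            if len(paths) < self.auto_recover:
                raise ArithmeticError("Not enough saves recorded to revert.")
            target = paths[-self.auto_recover]
            print(f"Reverting {label} -> {target}")
            # The real patch calls resume_training() for '__state__' and
            # load_network() for each network (plus its EMA copy) here.

if __name__ == '__main__':
    t = Trainer(auto_recover=2)
    for step in (1000, 2000, 3000):
        t.record_save('generator', f'models/{step}_generator.pth')
        t.record_save('__state__', f'training_state/{step}.state')
    t.recover_from_nan()  # reverts both labels to the 2000-step saves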