diff --git a/codes/models/base_model.py b/codes/models/base_model.py index f5013921..68959b3d 100644 --- a/codes/models/base_model.py +++ b/codes/models/base_model.py @@ -81,6 +81,9 @@ class BaseModel(): for key, param in state_dict.items(): state_dict[key] = param.cpu() torch.save(state_dict, save_path) + # Also save to the 'alt_path' which is useful for caching to Google Drive in colab, for example. + if 'alt_path' in self.opt['path'].keys(): + torch.save(state_dict, os.path.join(self.opt['path']['alt_path'], save_filename)) return save_path def load_network(self, load_path, network, strict=True): @@ -105,6 +108,9 @@ class BaseModel(): save_filename = '{}.state'.format(iter_step) save_path = os.path.join(self.opt['path']['training_state'], save_filename) torch.save(state, save_path) + # Also save to the 'alt_path' which is useful for caching to Google Drive in colab, for example. + if 'alt_path' in self.opt['path'].keys(): + torch.save(state, os.path.join(self.opt['path']['alt_path'], 'latest.state')) def resume_training(self, resume_state): """Resume the optimizers and schedulers for training"""