Allow an alt_path for saving models and states

This commit is contained in:
James Betker 2020-05-16 09:10:51 -06:00
parent f911ef0d3e
commit b95c4087d1

View File

@ -81,6 +81,9 @@ class BaseModel():
for key, param in state_dict.items():
state_dict[key] = param.cpu()
torch.save(state_dict, save_path)
# Also save to the 'alt_path' which is useful for caching to Google Drive in colab, for example.
if 'alt_path' in self.opt['path'].keys():
torch.save(state_dict, os.path.join(self.opt['path']['alt_path'], save_filename))
return save_path
def load_network(self, load_path, network, strict=True):
@ -105,6 +108,9 @@ class BaseModel():
save_filename = '{}.state'.format(iter_step)
save_path = os.path.join(self.opt['path']['training_state'], save_filename)
torch.save(state, save_path)
# Also save to the 'alt_path' which is useful for caching to Google Drive in colab, for example.
if 'alt_path' in self.opt['path'].keys():
torch.save(state, os.path.join(self.opt['path']['alt_path'], 'latest.state'))
def resume_training(self, resume_state):
"""Resume the optimizers and schedulers for training"""