Allow an alt_path for saving models and states
This commit is contained in:
parent
f911ef0d3e
commit
b95c4087d1
|
@ -81,6 +81,9 @@ class BaseModel():
|
||||||
for key, param in state_dict.items():
|
for key, param in state_dict.items():
|
||||||
state_dict[key] = param.cpu()
|
state_dict[key] = param.cpu()
|
||||||
torch.save(state_dict, save_path)
|
torch.save(state_dict, save_path)
|
||||||
|
# Also save to the 'alt_path' which is useful for caching to Google Drive in colab, for example.
|
||||||
|
if 'alt_path' in self.opt['path'].keys():
|
||||||
|
torch.save(state_dict, os.path.join(self.opt['path']['alt_path'], save_filename))
|
||||||
return save_path
|
return save_path
|
||||||
|
|
||||||
def load_network(self, load_path, network, strict=True):
|
def load_network(self, load_path, network, strict=True):
|
||||||
|
@ -105,6 +108,9 @@ class BaseModel():
|
||||||
save_filename = '{}.state'.format(iter_step)
|
save_filename = '{}.state'.format(iter_step)
|
||||||
save_path = os.path.join(self.opt['path']['training_state'], save_filename)
|
save_path = os.path.join(self.opt['path']['training_state'], save_filename)
|
||||||
torch.save(state, save_path)
|
torch.save(state, save_path)
|
||||||
|
# Also save to the 'alt_path' which is useful for caching to Google Drive in colab, for example.
|
||||||
|
if 'alt_path' in self.opt['path'].keys():
|
||||||
|
torch.save(state, os.path.join(self.opt['path']['alt_path'], 'latest.state'))
|
||||||
|
|
||||||
def resume_training(self, resume_state):
|
def resume_training(self, resume_state):
|
||||||
"""Resume the optimizers and schedulers for training"""
|
"""Resume the optimizers and schedulers for training"""
|
||||||
|
|
Loading…
Reference in New Issue
Block a user