forked from mrq/DL-Art-School
Allow an alt_path for saving models and states
This commit is contained in:
parent
f911ef0d3e
commit
b95c4087d1
|
@ -81,6 +81,9 @@ class BaseModel():
|
|||
for key, param in state_dict.items():
|
||||
state_dict[key] = param.cpu()
|
||||
torch.save(state_dict, save_path)
|
||||
# Also save to the 'alt_path' which is useful for caching to Google Drive in colab, for example.
|
||||
if 'alt_path' in self.opt['path'].keys():
|
||||
torch.save(state_dict, os.path.join(self.opt['path']['alt_path'], save_filename))
|
||||
return save_path
|
||||
|
||||
def load_network(self, load_path, network, strict=True):
|
||||
|
@ -105,6 +108,9 @@ class BaseModel():
|
|||
save_filename = '{}.state'.format(iter_step)
|
||||
save_path = os.path.join(self.opt['path']['training_state'], save_filename)
|
||||
torch.save(state, save_path)
|
||||
# Also save to the 'alt_path' which is useful for caching to Google Drive in colab, for example.
|
||||
if 'alt_path' in self.opt['path'].keys():
|
||||
torch.save(state, os.path.join(self.opt['path']['alt_path'], 'latest.state'))
|
||||
|
||||
def resume_training(self, resume_state):
|
||||
"""Resume the optimizers and schedulers for training"""
|
||||
|
|
Loading…
Reference in New Issue
Block a user