Don't load amp state dict if amp is disabled

This commit is contained in:
James Betker 2020-09-14 15:21:42 -06:00
parent 94deab2792
commit d0321ca5de

View File

@ -123,7 +123,7 @@ class BaseModel():
utils.util.copy_files_to_server(self.opt['ssh_server'], self.opt['ssh_username'], self.opt['ssh_password'],
save_path, os.path.join(self.opt['remote_path'], 'training_state', save_filename))
def resume_training(self, resume_state):
def resume_training(self, resume_state, load_amp=True):
"""Resume the optimizers and schedulers for training"""
resume_optimizers = resume_state['optimizers']
resume_schedulers = resume_state['schedulers']
@ -133,5 +133,5 @@ class BaseModel():
self.optimizers[i].load_state_dict(o)
for i, s in enumerate(resume_schedulers):
self.schedulers[i].load_state_dict(s)
if 'amp' in resume_state.keys():
if load_amp and 'amp' in resume_state.keys():
amp.load_state_dict(resume_state['amp'])