Don't load amp state dict if amp is disabled
This commit is contained in:
parent
94deab2792
commit
d0321ca5de
|
@ -123,7 +123,7 @@ class BaseModel():
|
|||
utils.util.copy_files_to_server(self.opt['ssh_server'], self.opt['ssh_username'], self.opt['ssh_password'],
|
||||
save_path, os.path.join(self.opt['remote_path'], 'training_state', save_filename))
|
||||
|
||||
def resume_training(self, resume_state):
|
||||
def resume_training(self, resume_state, load_amp=True):
|
||||
"""Resume the optimizers and schedulers for training"""
|
||||
resume_optimizers = resume_state['optimizers']
|
||||
resume_schedulers = resume_state['schedulers']
|
||||
|
@ -133,5 +133,5 @@ class BaseModel():
|
|||
self.optimizers[i].load_state_dict(o)
|
||||
for i, s in enumerate(resume_schedulers):
|
||||
self.schedulers[i].load_state_dict(s)
|
||||
if 'amp' in resume_state.keys():
|
||||
if load_amp and 'amp' in resume_state.keys():
|
||||
amp.load_state_dict(resume_state['amp'])
|
||||
|
|
Loading…
Reference in New Issue
Block a user