From d0321ca5de108095dd9f0b981ecb80cb086cf52f Mon Sep 17 00:00:00 2001 From: James Betker Date: Mon, 14 Sep 2020 15:21:42 -0600 Subject: [PATCH] Don't load amp state dict if amp is disabled --- codes/models/base_model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/codes/models/base_model.py b/codes/models/base_model.py index aeec4a65..37af08cf 100644 --- a/codes/models/base_model.py +++ b/codes/models/base_model.py @@ -123,7 +123,7 @@ class BaseModel(): utils.util.copy_files_to_server(self.opt['ssh_server'], self.opt['ssh_username'], self.opt['ssh_password'], save_path, os.path.join(self.opt['remote_path'], 'training_state', save_filename)) - def resume_training(self, resume_state): + def resume_training(self, resume_state, load_amp=True): """Resume the optimizers and schedulers for training""" resume_optimizers = resume_state['optimizers'] resume_schedulers = resume_state['schedulers'] @@ -133,5 +133,5 @@ class BaseModel(): self.optimizers[i].load_state_dict(o) for i, s in enumerate(resume_schedulers): self.schedulers[i].load_state_dict(s) - if 'amp' in resume_state.keys(): + if load_amp and 'amp' in resume_state.keys(): amp.load_state_dict(resume_state['amp'])