diff --git a/codes/models/base_model.py b/codes/models/base_model.py index 1b8348d1..aeec78a6 100644 --- a/codes/models/base_model.py +++ b/codes/models/base_model.py @@ -4,6 +4,7 @@ import torch import torch.nn as nn from torch.nn.parallel import DistributedDataParallel import utils.util +from apex import amp class BaseModel(): @@ -109,6 +110,7 @@ class BaseModel(): state['schedulers'].append(s.state_dict()) for o in self.optimizers: state['optimizers'].append(o.state_dict()) + state['amp'] = amp.state_dict() save_filename = '{}.state'.format(iter_step) save_path = os.path.join(self.opt['path']['training_state'], save_filename) torch.save(state, save_path) @@ -129,3 +131,4 @@ class BaseModel(): self.optimizers[i].load_state_dict(o) for i, s in enumerate(resume_schedulers): self.schedulers[i].load_state_dict(s) + amp.load_state_dict(resume_state['amp'])