Save & load amp state

James Betker 2020-06-18 11:38:48 -06:00
parent 2e3b6bad77
commit efc80f041c


@@ -4,6 +4,7 @@ import torch
 import torch.nn as nn
 from torch.nn.parallel import DistributedDataParallel
 import utils.util
+from apex import amp
 class BaseModel():
@@ -109,6 +110,7 @@ class BaseModel():
             state['schedulers'].append(s.state_dict())
         for o in self.optimizers:
             state['optimizers'].append(o.state_dict())
+        state['amp'] = amp.state_dict()
         save_filename = '{}.state'.format(iter_step)
         save_path = os.path.join(self.opt['path']['training_state'], save_filename)
         torch.save(state, save_path)
@@ -129,3 +131,4 @@ class BaseModel():
             self.optimizers[i].load_state_dict(o)
         for i, s in enumerate(resume_schedulers):
             self.schedulers[i].load_state_dict(s)
+        amp.load_state_dict(resume_state['amp'])
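
For reference, a minimal standalone sketch of the save/resume pattern this commit wires into BaseModel, assuming a CUDA-capable machine with apex installed. The toy model, optimizer, and checkpoint path below are illustrative assumptions, not code from this repository:

import torch
import torch.nn as nn
from apex import amp

# Hypothetical model/optimizer standing in for the ones BaseModel manages.
model = nn.Linear(16, 4).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# amp.initialize must run before amp state can be saved or restored.
model, optimizer = amp.initialize(model, optimizer, opt_level='O1')

# Save: bundle amp's loss-scaler state alongside the optimizer state,
# mirroring state['amp'] = amp.state_dict() in the diff above.
state = {'optimizer': optimizer.state_dict(), 'amp': amp.state_dict()}
torch.save(state, 'checkpoint.state')

# Resume: re-run amp.initialize first, then restore both state dicts,
# mirroring amp.load_state_dict(resume_state['amp']) in the diff above.
resume_state = torch.load('checkpoint.state')
optimizer.load_state_dict(resume_state['optimizer'])
amp.load_state_dict(resume_state['amp'])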