Save and load Apex amp state in the training-state checkpoint
This commit is contained in:
parent
2e3b6bad77
commit
efc80f041c
|
@ -4,6 +4,7 @@ import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torch.nn.parallel import DistributedDataParallel
|
from torch.nn.parallel import DistributedDataParallel
|
||||||
import utils.util
|
import utils.util
|
||||||
|
from apex import amp
|
||||||
|
|
||||||
|
|
||||||
class BaseModel():
|
class BaseModel():
|
||||||
|
@ -109,6 +110,7 @@ class BaseModel():
|
||||||
state['schedulers'].append(s.state_dict())
|
state['schedulers'].append(s.state_dict())
|
||||||
for o in self.optimizers:
|
for o in self.optimizers:
|
||||||
state['optimizers'].append(o.state_dict())
|
state['optimizers'].append(o.state_dict())
|
||||||
|
state['amp'] = amp.state_dict()
|
||||||
save_filename = '{}.state'.format(iter_step)
|
save_filename = '{}.state'.format(iter_step)
|
||||||
save_path = os.path.join(self.opt['path']['training_state'], save_filename)
|
save_path = os.path.join(self.opt['path']['training_state'], save_filename)
|
||||||
torch.save(state, save_path)
|
torch.save(state, save_path)
|
||||||
|
@ -129,3 +131,4 @@ class BaseModel():
|
||||||
self.optimizers[i].load_state_dict(o)
|
self.optimizers[i].load_state_dict(o)
|
||||||
for i, s in enumerate(resume_schedulers):
|
for i, s in enumerate(resume_schedulers):
|
||||||
self.schedulers[i].load_state_dict(s)
|
self.schedulers[i].load_state_dict(s)
|
||||||
|
amp.load_state_dict(resume_state['amp'])
|
||||||
|
|
Loading…
Reference in New Issue
Block a user