Shuffle optimizer states back and forth between CPU and GPU memory during steps
parent 3c242403f5
commit f490eaeba7
@@ -365,7 +365,9 @@ class ExtensibleTrainer(BaseModel):
     def consume_gradients(self, state, step, it):
         [e.before_optimize(state) for e in self.experiments]
+        self.restore_optimizers()
         step.do_step(it)
+        self.stash_optimizers()
 
         # Call into custom step hooks as well as update EMA params.
         for name, net in self.networks.items():
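Note (not part of this commit): the two added calls bracket the optimizer step so that optimizer state only has to sit in GPU memory while step.do_step(it) runs. A minimal standalone sketch of the same pattern in plain PyTorch, with a simplified, hypothetical stand-in for the optimizer_to helper this commit adds to utils/util.py (the model, sizes, and the helper's scalar-skipping behavior are illustrative):

import torch
import torch.nn as nn

def optimizer_to(opt, device):
    # Hypothetical stand-in for utils.util.optimizer_to (added below): move
    # per-parameter state tensors to `device`, skipping scalar bookkeeping
    # such as Adam's step counter, which recent PyTorch expects to stay put.
    for state in opt.state.values():
        for k, v in state.items():
            if isinstance(v, torch.Tensor) and v.ndim > 0:
                state[k] = v.to(device)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = nn.Linear(128, 128).to(device)
optimizer = torch.optim.Adam(model.parameters())

for _ in range(3):
    loss = model(torch.randn(16, 128, device=device)).sum()
    loss.backward()                      # forward/backward run with state stashed on the CPU
    optimizer_to(optimizer, device)      # restore_optimizers(): bring state back for the step
    optimizer.step()                     # Adam's exp_avg / exp_avg_sq are only needed here
    optimizer.zero_grad()
    optimizer_to(optimizer, 'cpu')       # stash_optimizers(): free device memory again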
@@ -6,6 +6,7 @@ from torch.distributed.optim import ZeroRedundancyOptimizer
 from torch.nn.parallel.distributed import DistributedDataParallel
 
 import utils.util
+from utils.util import opt_get, optimizer_to
 
 
 class BaseModel():
@@ -18,6 +19,7 @@ class BaseModel():
         self.device = torch.device('cuda' if opt['gpu_ids'] is not None else 'cpu')
         self.amp_level = 'O0' if opt['amp_opt_level'] is None else opt['amp_opt_level']
         self.is_train = opt['is_train']
+        self.opt_in_cpu = opt_get(opt, ['keep_optimizer_states_on_cpu'], False)
         self.schedulers = []
         self.optimizers = []
         self.disc_optimizers = []
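Note (not part of this commit): the new keep_optimizer_states_on_cpu flag is read through opt_get with a default of False, so existing configs keep the previous always-on-device behavior. A rough sketch of that lookup, with a hypothetical stand-in for the repo's opt_get helper (its real implementation lives in utils/util.py and is not shown in this diff):

def opt_get(opt, keys, default=None):
    # Stand-in for utils.util.opt_get: walk a nested dict, falling back to
    # `default` when any key along the path is missing.
    cur = opt
    for k in keys:
        if not isinstance(cur, dict) or k not in cur:
            return default
        cur = cur[k]
    return cur

train_opt = {'keep_optimizer_states_on_cpu': True}            # opts in to CPU stashing
assert opt_get(train_opt, ['keep_optimizer_states_on_cpu'], False) is True
assert opt_get({}, ['keep_optimizer_states_on_cpu'], False) is False   # default path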
@@ -158,6 +160,25 @@ class BaseModel():
             utils.util.copy_files_to_server(self.opt['ssh_server'], self.opt['ssh_username'], self.opt['ssh_password'],
                                             save_path, os.path.join(self.opt['remote_path'], 'training_state', save_filename))
 
+    def stash_optimizers(self):
+        """
+        When enabled, puts all optimizer states in CPU memory, allowing forward and backward passes more memory
+        headroom.
+        """
+        if not self.opt_in_cpu:
+            return
+        for opt in self.optimizers:
+            optimizer_to(opt, 'cpu')
+
+    def restore_optimizers(self):
+        """
+        Puts optimizer states back into device memory.
+        """
+        if not self.opt_in_cpu:
+            return
+        for opt in self.optimizers:
+            optimizer_to(opt, self.device)
+
     def resume_training(self, resume_state, load_amp=True):
         """Resume the optimizers and schedulers for training"""
         resume_optimizers = resume_state['optimizers']
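Note (not part of this commit): for Adam-style optimizers the stashed state (exp_avg and exp_avg_sq) amounts to roughly two extra fp32 copies of the trainable parameters, which is the GPU headroom stash_optimizers() frees up during the forward and backward passes. A rough way to observe the effect on a CUDA machine, assuming the repo's source root is on the import path (layer size is illustrative):

import torch
import torch.nn as nn
from utils.util import optimizer_to   # helper added by this commit in utils/util.py

model = nn.Linear(4096, 4096).cuda()
optimizer = torch.optim.Adam(model.parameters())
model(torch.randn(8, 4096, device='cuda')).sum().backward()
optimizer.step()   # lazily materializes exp_avg / exp_avg_sq on the GPU

before = torch.cuda.memory_allocated()
optimizer_to(optimizer, 'cpu')                 # what stash_optimizers() does per optimizer
after = torch.cuda.memory_allocated()
print(f'optimizer state moved off the GPU: ~{(before - after) / 2**20:.1f} MiB freed')

optimizer_to(optimizer, torch.device('cuda'))  # what restore_optimizers() does before the step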
@@ -171,3 +192,4 @@ class BaseModel():
         if load_amp and 'amp' in resume_state.keys():
             from apex import amp
             amp.load_state_dict(resume_state['amp'])
+        self.stash_optimizers()
@@ -508,3 +508,20 @@ def ceil_multiple(base, multiple):
     if res == 0:
         return base
     return base + (multiple - res)
+
+
+def optimizer_to(opt, device):
+    """
+    Pushes the optimizer params from opt onto the specified device.
+    """
+    for param in opt.state.values():
+        if isinstance(param, torch.Tensor):
+            param.data = param.data.to(device)
+            if param._grad is not None:
+                param._grad.data = param._grad.data.to(device)
+        elif isinstance(param, dict):
+            for subparam in param.values():
+                if isinstance(subparam, torch.Tensor):
+                    subparam.data = subparam.data.to(device)
+                    if subparam._grad is not None:
+                        subparam._grad.data = subparam._grad.data.to(device)
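Note (not part of this commit): for common optimizers such as Adam, opt.state maps each parameter to a dict of per-parameter buffers ('step', 'exp_avg', 'exp_avg_sq'), which is what the isinstance(param, dict) branch above walks; the bare-Tensor branch presumably covers optimizers that store a tensor directly. A quick, CPU-only way to inspect that layout (illustrative):

import torch

net = torch.nn.Linear(8, 8)
adam = torch.optim.Adam(net.parameters())
net(torch.randn(4, 8)).sum().backward()
adam.step()   # per-parameter state is created lazily on the first step

for param, state in adam.state.items():
    # Each value is a dict keyed by buffer name, e.g. ['exp_avg', 'exp_avg_sq', 'step'].
    print(param.shape, type(state).__name__, sorted(state.keys()))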