Shuffle optimizer states back and forth between CPU and GPU memory during steps

James Betker 2022-03-04 10:38:51 -07:00
parent 3c242403f5
commit f490eaeba7
3 changed files with 41 additions and 0 deletions


@@ -365,7 +365,9 @@ class ExtensibleTrainer(BaseModel):
    def consume_gradients(self, state, step, it):
        [e.before_optimize(state) for e in self.experiments]
        self.restore_optimizers()
        step.do_step(it)
        self.stash_optimizers()
        # Call into custom step hooks as well as update EMA params.
        for name, net in self.networks.items():
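
The hunk above brackets the optimizer step: optimizer state is pulled back onto the device only for the duration of step.do_step(it), then stashed again, so the forward and backward passes run with the extra GPU headroom. A minimal standalone sketch of the same idea, with an illustrative move_state helper standing in for optimizer_to (these names are not this repo's API):

# Illustrative sketch only; move_state plays the role of optimizer_to().
def training_step(model, optimizer, batch, move_state):
    loss = model(batch).mean()        # forward pass runs while optimizer state sits on the CPU
    loss.backward()                   # gradients are computed on the GPU as usual
    move_state(optimizer, 'cuda')     # restore optimizer state to the device
    optimizer.step()
    move_state(optimizer, 'cpu')      # stash it again until the next step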


@@ -6,6 +6,7 @@ from torch.distributed.optim import ZeroRedundancyOptimizer
from torch.nn.parallel.distributed import DistributedDataParallel
import utils.util
from utils.util import opt_get, optimizer_to
class BaseModel():
@@ -18,6 +19,7 @@ class BaseModel():
        self.device = torch.device('cuda' if opt['gpu_ids'] is not None else 'cpu')
        self.amp_level = 'O0' if opt['amp_opt_level'] is None else opt['amp_opt_level']
        self.is_train = opt['is_train']
        self.opt_in_cpu = opt_get(opt, ['keep_optimizer_states_on_cpu'], False)
        self.schedulers = []
        self.optimizers = []
        self.disc_optimizers = []
@@ -158,6 +160,25 @@ class BaseModel():
            utils.util.copy_files_to_server(self.opt['ssh_server'], self.opt['ssh_username'], self.opt['ssh_password'],
                                            save_path, os.path.join(self.opt['remote_path'], 'training_state', save_filename))

    def stash_optimizers(self):
        """
        When enabled, moves all optimizer states into CPU memory, giving the forward and backward passes more
        memory headroom on the GPU.
        """
        if not self.opt_in_cpu:
            return
        for opt in self.optimizers:
            optimizer_to(opt, 'cpu')

    def restore_optimizers(self):
        """
        Puts optimizer states back into device memory.
        """
        if not self.opt_in_cpu:
            return
        for opt in self.optimizers:
            optimizer_to(opt, self.device)

    def resume_training(self, resume_state, load_amp=True):
        """Resume the optimizers and schedulers for training"""
        resume_optimizers = resume_state['optimizers']
@@ -171,3 +192,4 @@ class BaseModel():
        if load_amp and 'amp' in resume_state.keys():
            from apex import amp
            amp.load_state_dict(resume_state['amp'])
        self.stash_optimizers()
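
The feature is opt-in via the keep_optimizer_states_on_cpu flag read through opt_get in __init__ above, defaulting to False. A hypothetical fragment of the training opt dict that would enable it (the surrounding keys mirror the ones referenced in __init__; the exact config layout may differ):

# Hypothetical opt fragment; 'keep_optimizer_states_on_cpu' is the new knob.
opt = {
    'gpu_ids': [0],
    'amp_opt_level': None,
    'is_train': True,
    'keep_optimizer_states_on_cpu': True,  # stash optimizer state in CPU memory between steps
}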


@@ -508,3 +508,20 @@ def ceil_multiple(base, multiple):
    if res == 0:
        return base
    return base + (multiple - res)

def optimizer_to(opt, device):
    """
    Moves the state tensors held by the optimizer `opt` onto the specified device.
    """
    for param in opt.state.values():
        # State entries are either bare tensors...
        if isinstance(param, torch.Tensor):
            param.data = param.data.to(device)
            if param._grad is not None:
                param._grad.data = param._grad.data.to(device)
        # ...or per-parameter dicts of tensors (e.g. Adam's exp_avg/exp_avg_sq).
        elif isinstance(param, dict):
            for subparam in param.values():
                if isinstance(subparam, torch.Tensor):
                    subparam.data = subparam.data.to(device)
                    if subparam._grad is not None:
                        subparam._grad.data = subparam._grad.data.to(device)
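
A quick usage sketch of optimizer_to, assuming a CUDA device is available: after one Adam step the optimizer holds its state tensors on the GPU, and they can be paged to the CPU and back.

# Usage sketch; assumes CUDA is available.
import torch

model = torch.nn.Linear(8, 8).cuda()
optimizer = torch.optim.Adam(model.parameters())
model(torch.randn(1, 8, device='cuda')).sum().backward()
optimizer.step()                 # Adam now holds exp_avg/exp_avg_sq tensors on the GPU
optimizer_to(optimizer, 'cpu')   # stash the state in CPU memory
optimizer_to(optimizer, 'cuda')  # restore it before the next step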