From f490eaeba7809c19904a66bd20066450aa6d9fd4 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Fri, 4 Mar 2022 10:38:51 -0700
Subject: [PATCH] Shuffle optimizer states back and forth between cpu memory during steps

---
 codes/trainer/ExtensibleTrainer.py |  2 ++
 codes/trainer/base_model.py        | 22 ++++++++++++++++++++++
 codes/utils/util.py                | 17 +++++++++++++++++
 3 files changed, 41 insertions(+)

diff --git a/codes/trainer/ExtensibleTrainer.py b/codes/trainer/ExtensibleTrainer.py
index a3db905d..e23dfa48 100644
--- a/codes/trainer/ExtensibleTrainer.py
+++ b/codes/trainer/ExtensibleTrainer.py
@@ -365,7 +365,9 @@ class ExtensibleTrainer(BaseModel):
 
     def consume_gradients(self, state, step, it):
         [e.before_optimize(state) for e in self.experiments]
+        self.restore_optimizers()
         step.do_step(it)
+        self.stash_optimizers()
 
         # Call into custom step hooks as well as update EMA params.
         for name, net in self.networks.items():
diff --git a/codes/trainer/base_model.py b/codes/trainer/base_model.py
index ba635f90..220957fc 100644
--- a/codes/trainer/base_model.py
+++ b/codes/trainer/base_model.py
@@ -6,6 +6,7 @@ from torch.distributed.optim import ZeroRedundancyOptimizer
 from torch.nn.parallel.distributed import DistributedDataParallel
 
 import utils.util
+from utils.util import opt_get, optimizer_to
 
 
 class BaseModel():
@@ -18,6 +19,7 @@ class BaseModel():
         self.device = torch.device('cuda' if opt['gpu_ids'] is not None else 'cpu')
         self.amp_level = 'O0' if opt['amp_opt_level'] is None else opt['amp_opt_level']
         self.is_train = opt['is_train']
+        self.opt_in_cpu = opt_get(opt, ['keep_optimizer_states_on_cpu'], False)
         self.schedulers = []
         self.optimizers = []
         self.disc_optimizers = []
@@ -158,6 +160,25 @@ class BaseModel():
             utils.util.copy_files_to_server(self.opt['ssh_server'], self.opt['ssh_username'], self.opt['ssh_password'],
                                             save_path, os.path.join(self.opt['remote_path'], 'training_state', save_filename))
 
+    def stash_optimizers(self):
+        """
+        When enabled, puts all optimizer states in CPU memory, allowing forward and backward passes more memory
+        headroom.
+        """
+        if not self.opt_in_cpu:
+            return
+        for opt in self.optimizers:
+            optimizer_to(opt, 'cpu')
+
+    def restore_optimizers(self):
+        """
+        Puts optimizer states back into device memory.
+        """
+        if not self.opt_in_cpu:
+            return
+        for opt in self.optimizers:
+            optimizer_to(opt, self.device)
+
     def resume_training(self, resume_state, load_amp=True):
         """Resume the optimizers and schedulers for training"""
         resume_optimizers = resume_state['optimizers']
@@ -171,3 +192,4 @@
         if load_amp and 'amp' in resume_state.keys():
             from apex import amp
             amp.load_state_dict(resume_state['amp'])
+        self.stash_optimizers()
diff --git a/codes/utils/util.py b/codes/utils/util.py
index f8d9fb26..bbc97772 100644
--- a/codes/utils/util.py
+++ b/codes/utils/util.py
@@ -508,3 +508,20 @@ def ceil_multiple(base, multiple):
     if res == 0:
         return base
     return base + (multiple - res)
+
+
+def optimizer_to(opt, device):
+    """
+    Pushes the optimizer params from opt onto the specified device.
+    """
+    for param in opt.state.values():
+        if isinstance(param, torch.Tensor):
+            param.data = param.data.to(device)
+            if param._grad is not None:
+                param._grad.data = param._grad.data.to(device)
+        elif isinstance(param, dict):
+            for subparam in param.values():
+                if isinstance(subparam, torch.Tensor):
+                    subparam.data = subparam.data.to(device)
+                    if subparam._grad is not None:
+                        subparam._grad.data = subparam._grad.data.to(device)
\ No newline at end of file
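
Note: the snippet below is not part of the diff above. It is a minimal, self-contained sketch of the pattern the patch wires into consume_gradients(): optimizer state lives in CPU memory except for the brief window around the optimizer step, which leaves more GPU headroom for the forward and backward passes. The toy model, the training loop, and the simplified optimizer_to() stand-in are assumptions made for illustration; the actual helper added to codes/utils/util.py also moves any gradient tensors attached to the stashed state.

import torch


def optimizer_to(opt, device):
    # Simplified stand-in for the optimizer_to() helper added to codes/utils/util.py:
    # move every tensor held in the optimizer's per-parameter state dicts to `device`.
    # The scalar 'step' counter is left where it is, since recent PyTorch releases
    # expect non-capturable Adam step counters to stay on the CPU.
    for state in opt.state.values():
        for k, v in state.items():
            if isinstance(v, torch.Tensor) and k != 'step':
                state[k] = v.to(device)


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = torch.nn.Linear(256, 256).to(device)
optimizer = torch.optim.Adam(model.parameters())

for _ in range(3):
    loss = model(torch.randn(8, 256, device=device)).sum()
    loss.backward()

    optimizer_to(optimizer, device)   # restore_optimizers(): state back on the device
    optimizer.step()                  # the actual update, step.do_step(it) in the patch
    optimizer_to(optimizer, 'cpu')    # stash_optimizers(): free device memory again

    optimizer.zero_grad()

Judging from the opt_get() default in BaseModel.__init__(), the shuffle only happens when the training configuration sets keep_optimizer_states_on_cpu to true; otherwise stash_optimizers() and restore_optimizers() return immediately. The trade-off is a host-to-device copy of the optimizer state (for Adam, roughly two extra tensors per parameter) on every step, so this buys memory headroom at the cost of some step latency.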