forked from mrq/DL-Art-School
Shuffle optimizer states back and forth between CPU and GPU memory during steps
parent 3c242403f5
commit f490eaeba7
@@ -365,7 +365,9 @@ class ExtensibleTrainer(BaseModel):
     def consume_gradients(self, state, step, it):
         [e.before_optimize(state) for e in self.experiments]
+        self.restore_optimizers()
         step.do_step(it)
+        self.stash_optimizers()
 
         # Call into custom step hooks as well as update EMA params.
         for name, net in self.networks.items():
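The ordering matters: torch.optim optimizers keep per-parameter state tensors (e.g. Adam's exp_avg and exp_avg_sq), and step() expects that state on the same device as the parameters, so the state has to be restored before do_step and can only be stashed again once the step has run. A minimal sketch of the same restore/step/stash pattern as a hypothetical context manager (not part of this commit) built on the optimizer_to helper added in utils.util below:

import contextlib
from utils.util import optimizer_to  # helper added by this commit, see the utils hunk below

@contextlib.contextmanager
def optimizer_on_device(opt, device):
    # Restore: the optimizer state must sit on the same device as the params before step() runs.
    optimizer_to(opt, device)
    try:
        yield opt
    finally:
        # Stash: push the state back to CPU memory so the next forward/backward pass gets the headroom.
        optimizer_to(opt, 'cpu')

# Usage sketch:
#   with optimizer_on_device(optimizer, torch.device('cuda')):
#       optimizer.step()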
@@ -6,6 +6,7 @@ from torch.distributed.optim import ZeroRedundancyOptimizer
 from torch.nn.parallel.distributed import DistributedDataParallel
 
 import utils.util
+from utils.util import opt_get, optimizer_to
 
 
 class BaseModel():
@@ -18,6 +19,7 @@ class BaseModel():
         self.device = torch.device('cuda' if opt['gpu_ids'] is not None else 'cpu')
         self.amp_level = 'O0' if opt['amp_opt_level'] is None else opt['amp_opt_level']
         self.is_train = opt['is_train']
+        self.opt_in_cpu = opt_get(opt, ['keep_optimizer_states_on_cpu'], False)
         self.schedulers = []
         self.optimizers = []
         self.disc_optimizers = []
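opt_get falls back to the supplied default (False) when the key is missing, so the new behaviour is strictly opt-in and existing configurations are unaffected. A sketch of the in-memory options dict BaseModel reads here (values are illustrative; only keep_optimizer_states_on_cpu is new in this commit):

opt = {
    'gpu_ids': [0],                         # any non-None value selects the CUDA device
    'amp_opt_level': None,                  # defaults to 'O0'
    'is_train': True,
    'keep_optimizer_states_on_cpu': True,   # new flag: park optimizer state in CPU memory between steps
}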
@@ -158,6 +160,25 @@ class BaseModel():
             utils.util.copy_files_to_server(self.opt['ssh_server'], self.opt['ssh_username'], self.opt['ssh_password'],
                                             save_path, os.path.join(self.opt['remote_path'], 'training_state', save_filename))
 
+    def stash_optimizers(self):
+        """
+        When enabled, puts all optimizer states in CPU memory, allowing forward and backward passes more memory
+        headroom.
+        """
+        if not self.opt_in_cpu:
+            return
+        for opt in self.optimizers:
+            optimizer_to(opt, 'cpu')
+
+    def restore_optimizers(self):
+        """
+        Puts optimizer states back into device memory.
+        """
+        if not self.opt_in_cpu:
+            return
+        for opt in self.optimizers:
+            optimizer_to(opt, self.device)
+
     def resume_training(self, resume_state, load_amp=True):
         """Resume the optimizers and schedulers for training"""
         resume_optimizers = resume_state['optimizers']
@@ -171,3 +192,4 @@ class BaseModel():
         if load_amp and 'amp' in resume_state.keys():
             from apex import amp
             amp.load_state_dict(resume_state['amp'])
+        self.stash_optimizers()
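The stash_optimizers() call at the end of resume_training covers the checkpoint path too: Optimizer.load_state_dict() casts resumed state tensors onto the device of the parameters they belong to, so right after a resume the state would otherwise sit in GPU memory again. A rough sketch of that sequence (the zip loop is illustrative, not the exact loading code in resume_training):

for o, resumed in zip(self.optimizers, resume_optimizers):
    o.load_state_dict(resumed)  # state tensors are cast to live alongside the (typically CUDA) params
self.stash_optimizers()         # no-op unless keep_optimizer_states_on_cpu is set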
@@ -508,3 +508,20 @@ def ceil_multiple(base, multiple):
     if res == 0:
         return base
     return base + (multiple - res)
+
+
+def optimizer_to(opt, device):
+    """
+    Pushes the optimizer params from opt onto the specified device.
+    """
+    for param in opt.state.values():
+        if isinstance(param, torch.Tensor):
+            param.data = param.data.to(device)
+            if param._grad is not None:
+                param._grad.data = param._grad.data.to(device)
+        elif isinstance(param, dict):
+            for subparam in param.values():
+                if isinstance(subparam, torch.Tensor):
+                    subparam.data = subparam.data.to(device)
+                    if subparam._grad is not None:
+                        subparam._grad.data = subparam._grad.data.to(device)
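For reference, a self-contained usage sketch of optimizer_to outside the trainer (the model and optimizer setup below is assumed, not taken from the repository). SGD with momentum is used so the only state is the momentum buffer; the savings are proportionally larger for Adam, whose exp_avg/exp_avg_sq state is roughly twice the parameter count:

import torch
from utils.util import optimizer_to  # the helper added above

model = torch.nn.Linear(4096, 4096).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# One step to materialize the momentum buffers on the GPU.
model(torch.randn(8, 4096, device='cuda')).sum().backward()
optimizer.step()
optimizer.zero_grad()

optimizer_to(optimizer, 'cpu')                 # stash: the state tensors no longer occupy GPU memory
print(torch.cuda.memory_allocated())           # illustrative check; lower while forward/backward run

model(torch.randn(8, 4096, device='cuda')).sum().backward()
optimizer_to(optimizer, torch.device('cuda'))  # restore: step() needs the state on the params' device
optimizer.step()
optimizer.zero_grad()
optimizer_to(optimizer, 'cpu')                 # stash again until the next step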