ExtensibleTrainer mods to support advanced checkpointing for StyleGAN2
Basically: StyleGAN2 makes use of gradient-based normalizers, which prevent me from using gradient checkpointing. But I love gradient checkpointing: it makes things really, really fast and memory-conscious. So checkpointing is now disabled only on the steps where the regularizer loss runs. This is a bit messy, but it speeds up training by at least 20%. Also, PyTorch: please make checkpointing a first-class citizen.
This commit is contained in:
parent
db9e9e28a0
commit
44a19cd37c
|
@ -40,6 +40,8 @@ class ExtensibleTrainer(BaseModel):
|
|||
if self.is_train:
|
||||
self.mega_batch_factor = train_opt['mega_batch_factor']
|
||||
self.env['mega_batch_factor'] = self.mega_batch_factor
|
||||
self.batch_factor = self.mega_batch_factor
|
||||
self.checkpointing_cache = opt['checkpointing_enabled']
|
||||
|
||||
self.netsG = {}
|
||||
self.netsD = {}
|
||||
|
@ -144,17 +146,27 @@ class ExtensibleTrainer(BaseModel):
|
|||
# Setting this to false triggers SRGAN to call the model's update_model() function on the first iteration.
|
||||
self.updated = True
|
||||
|
||||
def feed_data(self, data, need_GT=True):
|
||||
def feed_data(self, data, step, need_GT=True):
|
||||
self.env['step'] = step
|
||||
self.batch_factor = self.mega_batch_factor
|
||||
self.opt['checkpointing_enabled'] = self.checkpointing_cache
|
||||
# The batch factor can be adjusted periodically to allow known high-memory steps to fit in GPU memory.
|
||||
if 'mod_batch_factor' in self.opt['train'].keys() and \
|
||||
self.env['step'] % self.opt['train']['mod_batch_factor_every'] == 0:
|
||||
self.batch_factor = self.opt['train']['mod_batch_factor']
|
||||
if self.opt['train']['mod_batch_factor_also_disable_checkpointing']:
|
||||
self.opt['checkpointing_enabled'] = False
|
||||
|
||||
self.eval_state = {}
|
||||
for o in self.optimizers:
|
||||
o.zero_grad()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
self.lq = [t.to(self.device) for t in torch.chunk(data['LQ'], chunks=self.mega_batch_factor, dim=0)]
|
||||
self.lq = [t.to(self.device) for t in torch.chunk(data['LQ'], chunks=self.batch_factor, dim=0)]
|
||||
if need_GT:
|
||||
self.hq = [t.to(self.device) for t in torch.chunk(data['GT'], chunks=self.mega_batch_factor, dim=0)]
|
||||
self.hq = [t.to(self.device) for t in torch.chunk(data['GT'], chunks=self.batch_factor, dim=0)]
|
||||
input_ref = data['ref'] if 'ref' in data.keys() else data['GT']
|
||||
self.ref = [t.to(self.device) for t in torch.chunk(input_ref, chunks=self.mega_batch_factor, dim=0)]
|
||||
self.ref = [t.to(self.device) for t in torch.chunk(input_ref, chunks=self.batch_factor, dim=0)]
|
||||
else:
|
||||
self.hq = self.lq
|
||||
self.ref = self.lq
|
||||
|
@ -162,11 +174,9 @@ class ExtensibleTrainer(BaseModel):
|
|||
self.dstate = {'lq': self.lq, 'hq': self.hq, 'ref': self.ref}
|
||||
for k, v in data.items():
|
||||
if k not in ['LQ', 'ref', 'GT'] and isinstance(v, torch.Tensor):
|
||||
self.dstate[k] = [t.to(self.device) for t in torch.chunk(v, chunks=self.mega_batch_factor, dim=0)]
|
||||
self.dstate[k] = [t.to(self.device) for t in torch.chunk(v, chunks=self.batch_factor, dim=0)]
|
||||
|
||||
def optimize_parameters(self, step):
|
||||
self.env['step'] = step
|
||||
|
||||
# Some models need to make parametric adjustments per-step. Do that here.
|
||||
for net in self.networks.values():
|
||||
if hasattr(net.module, "update_for_step"):
|
||||
|
@ -218,7 +228,7 @@ class ExtensibleTrainer(BaseModel):
|
|||
|
||||
# Now do a forward and backward pass for each gradient accumulation step.
|
||||
new_states = {}
|
||||
for m in range(self.mega_batch_factor):
|
||||
for m in range(self.batch_factor):
|
||||
ns = s.do_forward_backward(state, m, step_num, train=train_step)
|
||||
for k, v in ns.items():
|
||||
if k not in new_states.keys():
|
||||
|
|
Loading…
Reference in New Issue
Block a user