ExtensibleTrainer mods to support advanced checkpointing for stylegan2

Basically: stylegan2 makes use of gradient-based normalizers. These
make it so that I cannot use gradient checkpointing. But I love gradient
checkpointing. It makes things really, really fast and memory conscious.

So: skip gradient checkpointing only on the iterations that run the
regularizer loss. This is a bit messy, but it speeds up training by at least 20%.

Also: pytorch: please make checkpointing a first class citizen.
This commit is contained in:
James Betker 2020-11-12 15:44:47 -07:00
parent db9e9e28a0
commit 44a19cd37c

View File

@@ -40,6 +40,8 @@ class ExtensibleTrainer(BaseModel):
if self.is_train: if self.is_train:
self.mega_batch_factor = train_opt['mega_batch_factor'] self.mega_batch_factor = train_opt['mega_batch_factor']
self.env['mega_batch_factor'] = self.mega_batch_factor self.env['mega_batch_factor'] = self.mega_batch_factor
self.batch_factor = self.mega_batch_factor
self.checkpointing_cache = opt['checkpointing_enabled']
self.netsG = {} self.netsG = {}
self.netsD = {} self.netsD = {}
@@ -144,17 +146,27 @@ class ExtensibleTrainer(BaseModel):
# Setting this to false triggers SRGAN to call the models update_model() function on the first iteration. # Setting this to false triggers SRGAN to call the models update_model() function on the first iteration.
self.updated = True self.updated = True
def feed_data(self, data, need_GT=True): def feed_data(self, data, step, need_GT=True):
self.env['step'] = step
self.batch_factor = self.mega_batch_factor
self.opt['checkpointing_enabled'] = self.checkpointing_cache
# The batch factor can be adjusted on a period to allow known high-memory steps to fit in GPU memory.
if 'mod_batch_factor' in self.opt['train'].keys() and \
self.env['step'] % self.opt['train']['mod_batch_factor_every'] == 0:
self.batch_factor = self.opt['train']['mod_batch_factor']
if self.opt['train']['mod_batch_factor_also_disable_checkpointing']:
self.opt['checkpointing_enabled'] = False
self.eval_state = {} self.eval_state = {}
for o in self.optimizers: for o in self.optimizers:
o.zero_grad() o.zero_grad()
torch.cuda.empty_cache() torch.cuda.empty_cache()
self.lq = [t.to(self.device) for t in torch.chunk(data['LQ'], chunks=self.mega_batch_factor, dim=0)] self.lq = [t.to(self.device) for t in torch.chunk(data['LQ'], chunks=self.batch_factor, dim=0)]
if need_GT: if need_GT:
self.hq = [t.to(self.device) for t in torch.chunk(data['GT'], chunks=self.mega_batch_factor, dim=0)] self.hq = [t.to(self.device) for t in torch.chunk(data['GT'], chunks=self.batch_factor, dim=0)]
input_ref = data['ref'] if 'ref' in data.keys() else data['GT'] input_ref = data['ref'] if 'ref' in data.keys() else data['GT']
self.ref = [t.to(self.device) for t in torch.chunk(input_ref, chunks=self.mega_batch_factor, dim=0)] self.ref = [t.to(self.device) for t in torch.chunk(input_ref, chunks=self.batch_factor, dim=0)]
else: else:
self.hq = self.lq self.hq = self.lq
self.ref = self.lq self.ref = self.lq
@@ -162,11 +174,9 @@ class ExtensibleTrainer(BaseModel):
self.dstate = {'lq': self.lq, 'hq': self.hq, 'ref': self.ref} self.dstate = {'lq': self.lq, 'hq': self.hq, 'ref': self.ref}
for k, v in data.items(): for k, v in data.items():
if k not in ['LQ', 'ref', 'GT'] and isinstance(v, torch.Tensor): if k not in ['LQ', 'ref', 'GT'] and isinstance(v, torch.Tensor):
self.dstate[k] = [t.to(self.device) for t in torch.chunk(v, chunks=self.mega_batch_factor, dim=0)] self.dstate[k] = [t.to(self.device) for t in torch.chunk(v, chunks=self.batch_factor, dim=0)]
def optimize_parameters(self, step): def optimize_parameters(self, step):
self.env['step'] = step
# Some models need to make parametric adjustments per-step. Do that here. # Some models need to make parametric adjustments per-step. Do that here.
for net in self.networks.values(): for net in self.networks.values():
if hasattr(net.module, "update_for_step"): if hasattr(net.module, "update_for_step"):
@@ -218,7 +228,7 @@ class ExtensibleTrainer(BaseModel):
# Now do a forward and backward pass for each gradient accumulation step. # Now do a forward and backward pass for each gradient accumulation step.
new_states = {} new_states = {}
for m in range(self.mega_batch_factor): for m in range(self.batch_factor):
ns = s.do_forward_backward(state, m, step_num, train=train_step) ns = s.do_forward_backward(state, m, step_num, train=train_step)
for k, v in ns.items(): for k, v in ns.items():
if k not in new_states.keys(): if k not in new_states.keys():