From 44a19cd37c21a2e4752c216c4e4a5779498f67d4 Mon Sep 17 00:00:00 2001 From: James Betker Date: Thu, 12 Nov 2020 15:44:47 -0700 Subject: [PATCH] ExtensibleTrainer mods to support advanced checkpointing for stylegan2 Basically: stylegan2 makes use of gradient-based normalizers. These make it so that I cannot use gradient checkpointing. But I love gradient checkpointing. It makes things really, really fast and memory conscious. So - only don't checkpoint when we run the regularizer loss. This is a bit messy, but speeds up training by at least 20%. Also: pytorch: please make checkpointing a first class citizen. --- codes/models/ExtensibleTrainer.py | 26 ++++++++++++++++++-------- 1 file changed, 18 insertions(+), 8 deletions(-) diff --git a/codes/models/ExtensibleTrainer.py b/codes/models/ExtensibleTrainer.py index 81d2746c..d7f1bd49 100644 --- a/codes/models/ExtensibleTrainer.py +++ b/codes/models/ExtensibleTrainer.py @@ -40,6 +40,8 @@ class ExtensibleTrainer(BaseModel): if self.is_train: self.mega_batch_factor = train_opt['mega_batch_factor'] self.env['mega_batch_factor'] = self.mega_batch_factor + self.batch_factor = self.mega_batch_factor + self.checkpointing_cache = opt['checkpointing_enabled'] self.netsG = {} self.netsD = {} @@ -144,17 +146,27 @@ class ExtensibleTrainer(BaseModel): # Setting this to false triggers SRGAN to call the models update_model() function on the first iteration. self.updated = True - def feed_data(self, data, need_GT=True): + def feed_data(self, data, step, need_GT=True): + self.env['step'] = step + self.batch_factor = self.mega_batch_factor + self.opt['checkpointing_enabled'] = self.checkpointing_cache + # The batch factor can be adjusted on a period to allow known high-memory steps to fit in GPU memory. + if 'mod_batch_factor' in self.opt['train'].keys() and \ + self.env['step'] % self.opt['train']['mod_batch_factor_every'] == 0: + self.batch_factor = self.opt['train']['mod_batch_factor'] + if self.opt['train']['mod_batch_factor_also_disable_checkpointing']: + self.opt['checkpointing_enabled'] = False + self.eval_state = {} for o in self.optimizers: o.zero_grad() torch.cuda.empty_cache() - self.lq = [t.to(self.device) for t in torch.chunk(data['LQ'], chunks=self.mega_batch_factor, dim=0)] + self.lq = [t.to(self.device) for t in torch.chunk(data['LQ'], chunks=self.batch_factor, dim=0)] if need_GT: - self.hq = [t.to(self.device) for t in torch.chunk(data['GT'], chunks=self.mega_batch_factor, dim=0)] + self.hq = [t.to(self.device) for t in torch.chunk(data['GT'], chunks=self.batch_factor, dim=0)] input_ref = data['ref'] if 'ref' in data.keys() else data['GT'] - self.ref = [t.to(self.device) for t in torch.chunk(input_ref, chunks=self.mega_batch_factor, dim=0)] + self.ref = [t.to(self.device) for t in torch.chunk(input_ref, chunks=self.batch_factor, dim=0)] else: self.hq = self.lq self.ref = self.lq @@ -162,11 +174,9 @@ class ExtensibleTrainer(BaseModel): self.dstate = {'lq': self.lq, 'hq': self.hq, 'ref': self.ref} for k, v in data.items(): if k not in ['LQ', 'ref', 'GT'] and isinstance(v, torch.Tensor): - self.dstate[k] = [t.to(self.device) for t in torch.chunk(v, chunks=self.mega_batch_factor, dim=0)] + self.dstate[k] = [t.to(self.device) for t in torch.chunk(v, chunks=self.batch_factor, dim=0)] def optimize_parameters(self, step): - self.env['step'] = step - # Some models need to make parametric adjustments per-step. Do that here. for net in self.networks.values(): if hasattr(net.module, "update_for_step"): @@ -218,7 +228,7 @@ class ExtensibleTrainer(BaseModel): # Now do a forward and backward pass for each gradient accumulation step. new_states = {} - for m in range(self.mega_batch_factor): + for m in range(self.batch_factor): ns = s.do_forward_backward(state, m, step_num, train=train_step) for k, v in ns.items(): if k not in new_states.keys():