diff --git a/codes/models/ExtensibleTrainer.py b/codes/models/ExtensibleTrainer.py index 81d2746c..d7f1bd49 100644 --- a/codes/models/ExtensibleTrainer.py +++ b/codes/models/ExtensibleTrainer.py @@ -40,6 +40,8 @@ class ExtensibleTrainer(BaseModel): if self.is_train: self.mega_batch_factor = train_opt['mega_batch_factor'] self.env['mega_batch_factor'] = self.mega_batch_factor + self.batch_factor = self.mega_batch_factor + self.checkpointing_cache = opt['checkpointing_enabled'] self.netsG = {} self.netsD = {} @@ -144,17 +146,27 @@ class ExtensibleTrainer(BaseModel): # Setting this to false triggers SRGAN to call the models update_model() function on the first iteration. self.updated = True - def feed_data(self, data, need_GT=True): + def feed_data(self, data, step, need_GT=True): + self.env['step'] = step + self.batch_factor = self.mega_batch_factor + self.opt['checkpointing_enabled'] = self.checkpointing_cache + # The batch factor can be adjusted on a period to allow known high-memory steps to fit in GPU memory. + if 'mod_batch_factor' in self.opt['train'].keys() and \ + self.env['step'] % self.opt['train']['mod_batch_factor_every'] == 0: + self.batch_factor = self.opt['train']['mod_batch_factor'] + if self.opt['train']['mod_batch_factor_also_disable_checkpointing']: + self.opt['checkpointing_enabled'] = False + self.eval_state = {} for o in self.optimizers: o.zero_grad() torch.cuda.empty_cache() - self.lq = [t.to(self.device) for t in torch.chunk(data['LQ'], chunks=self.mega_batch_factor, dim=0)] + self.lq = [t.to(self.device) for t in torch.chunk(data['LQ'], chunks=self.batch_factor, dim=0)] if need_GT: - self.hq = [t.to(self.device) for t in torch.chunk(data['GT'], chunks=self.mega_batch_factor, dim=0)] + self.hq = [t.to(self.device) for t in torch.chunk(data['GT'], chunks=self.batch_factor, dim=0)] input_ref = data['ref'] if 'ref' in data.keys() else data['GT'] - self.ref = [t.to(self.device) for t in torch.chunk(input_ref, chunks=self.mega_batch_factor, dim=0)] + self.ref = [t.to(self.device) for t in torch.chunk(input_ref, chunks=self.batch_factor, dim=0)] else: self.hq = self.lq self.ref = self.lq @@ -162,11 +174,9 @@ class ExtensibleTrainer(BaseModel): self.dstate = {'lq': self.lq, 'hq': self.hq, 'ref': self.ref} for k, v in data.items(): if k not in ['LQ', 'ref', 'GT'] and isinstance(v, torch.Tensor): - self.dstate[k] = [t.to(self.device) for t in torch.chunk(v, chunks=self.mega_batch_factor, dim=0)] + self.dstate[k] = [t.to(self.device) for t in torch.chunk(v, chunks=self.batch_factor, dim=0)] def optimize_parameters(self, step): - self.env['step'] = step - # Some models need to make parametric adjustments per-step. Do that here. for net in self.networks.values(): if hasattr(net.module, "update_for_step"): @@ -218,7 +228,7 @@ class ExtensibleTrainer(BaseModel): # Now do a forward and backward pass for each gradient accumulation step. new_states = {} - for m in range(self.mega_batch_factor): + for m in range(self.batch_factor): ns = s.do_forward_backward(state, m, step_num, train=train_step) for k, v in ns.items(): if k not in new_states.keys():