diff --git a/codes/train.py b/codes/train.py
index d582904f..65175e77 100644
--- a/codes/train.py
+++ b/codes/train.py
@@ -221,7 +221,7 @@ class Trainer:
         if opt_get(opt, ['eval', 'pure'], False) and self.current_step % opt['train']['val_freq'] == 0:
             metrics = []
             for val_data in tqdm(self.val_loader):
-                self.model.feed_data(val_data, self.current_step)
+                self.model.feed_data(val_data, self.current_step, perform_micro_batching=False)
                 metrics.append(self.model.test())
             reduced_metrics = {}
             for metric in metrics:
diff --git a/codes/trainer/ExtensibleTrainer.py b/codes/trainer/ExtensibleTrainer.py
index 1b08b281..a327bdb3 100644
--- a/codes/trainer/ExtensibleTrainer.py
+++ b/codes/trainer/ExtensibleTrainer.py
@@ -157,7 +157,7 @@ class ExtensibleTrainer(BaseModel):
         # Setting this to false triggers SRGAN to call the models update_model() function on the first iteration.
         self.updated = True
 
-    def feed_data(self, data, step, need_GT=True):
+    def feed_data(self, data, step, need_GT=True, perform_micro_batching=True):
         self.env['step'] = step
         self.batch_factor = self.mega_batch_factor
         self.opt['checkpointing_enabled'] = self.checkpointing_cache
@@ -174,10 +174,11 @@ class ExtensibleTrainer(BaseModel):
             o.zero_grad()
         torch.cuda.empty_cache()
 
+        batch_factor = self.batch_factor if perform_micro_batching else 1
         self.dstate = {}
         for k, v in data.items():
             if isinstance(v, torch.Tensor):
-                self.dstate[k] = [t.to(self.device) for t in torch.chunk(v, chunks=self.batch_factor, dim=0)]
+                self.dstate[k] = [t.to(self.device) for t in torch.chunk(v, chunks=batch_factor, dim=0)]
 
     def optimize_parameters(self, step):
         # Some models need to make parametric adjustments per-step. Do that here.
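
For context (this note and the sketch below are not part of the patch): the change lets validation feed the whole batch to the model at once instead of splitting it into mega_batch_factor micro-batches via torch.chunk. The following is a minimal, standalone sketch of that chunking behavior; the helper name split_for_micro_batching and the tensor keys/shapes are illustrative assumptions, not code from the repository.

# Sketch only: mirrors the torch.chunk splitting done in ExtensibleTrainer.feed_data.
# Names and shapes here are hypothetical and chosen just to show the two code paths.
import torch

def split_for_micro_batching(batch, mega_batch_factor, perform_micro_batching=True):
    # When micro-batching is disabled, fall back to a single chunk per tensor.
    batch_factor = mega_batch_factor if perform_micro_batching else 1
    return {k: list(torch.chunk(v, chunks=batch_factor, dim=0))
            for k, v in batch.items() if isinstance(v, torch.Tensor)}

batch = {'hq': torch.randn(8, 3, 64, 64), 'lq': torch.randn(8, 3, 16, 16)}

# Training path: each tensor is split into mega_batch_factor micro-batches along dim 0.
train_state = split_for_micro_batching(batch, mega_batch_factor=4)
assert len(train_state['hq']) == 4 and train_state['hq'][0].shape[0] == 2

# Evaluation path (perform_micro_batching=False): one chunk per key, so the
# model's test() step sees the full validation batch in a single pass.
eval_state = split_for_micro_batching(batch, mega_batch_factor=4, perform_micro_batching=False)
assert len(eval_state['hq']) == 1 and eval_state['hq'][0].shape[0] == 8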