No batch factors for eval

James Betker 2021-08-09 16:02:01 -06:00
parent 82fc69abfa
commit 04d14b3acc
2 changed files with 4 additions and 3 deletions

@@ -221,7 +221,7 @@ class Trainer:
         if opt_get(opt, ['eval', 'pure'], False) and self.current_step % opt['train']['val_freq'] == 0:
             metrics = []
             for val_data in tqdm(self.val_loader):
-                self.model.feed_data(val_data, self.current_step)
+                self.model.feed_data(val_data, self.current_step, perform_micro_batching=False)
                 metrics.append(self.model.test())
             reduced_metrics = {}
             for metric in metrics:
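
The micro-batching this flag bypasses is just torch.chunk applied to the input batch: during training the batch is split into mega_batch_factor slices so gradient accumulation can run over pieces that fit in memory, but a pure-eval pass accumulates no gradients, so the whole validation batch can go through at once. A minimal sketch of the difference (the tensor shape and factor value below are illustrative, not taken from this repo's configs):

import torch

val_batch = torch.randn(8, 3, 32, 32)   # hypothetical validation batch
mega_batch_factor = 4                    # illustrative value

# Training path: split the batch into micro-batches for gradient accumulation.
train_chunks = torch.chunk(val_batch, chunks=mega_batch_factor, dim=0)
print([tuple(t.shape) for t in train_chunks])   # 4 chunks of shape (2, 3, 32, 32)

# Eval path after this commit: chunks=1 leaves the batch whole.
eval_chunks = torch.chunk(val_batch, chunks=1, dim=0)
print([tuple(t.shape) for t in eval_chunks])    # 1 chunk of shape (8, 3, 32, 32)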


@@ -157,7 +157,7 @@ class ExtensibleTrainer(BaseModel):
         # Setting this to false triggers SRGAN to call the models update_model() function on the first iteration.
         self.updated = True
 
-    def feed_data(self, data, step, need_GT=True):
+    def feed_data(self, data, step, need_GT=True, perform_micro_batching=True):
         self.env['step'] = step
         self.batch_factor = self.mega_batch_factor
         self.opt['checkpointing_enabled'] = self.checkpointing_cache
@@ -174,10 +174,11 @@ class ExtensibleTrainer(BaseModel):
             o.zero_grad()
         torch.cuda.empty_cache()
 
+        batch_factor = self.batch_factor if perform_micro_batching else 1
         self.dstate = {}
         for k, v in data.items():
             if isinstance(v, torch.Tensor):
-                self.dstate[k] = [t.to(self.device) for t in torch.chunk(v, chunks=self.batch_factor, dim=0)]
+                self.dstate[k] = [t.to(self.device) for t in torch.chunk(v, chunks=batch_factor, dim=0)]
 
     def optimize_parameters(self, step):
         # Some models need to make parametric adjustments per-step. Do that here.
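
For reference, a stripped-down stand-in for the changed method, showing how the new flag flows into the chunking. The toy class below is purely illustrative and omits everything else ExtensibleTrainer.feed_data does (optimizer zeroing, checkpointing flags, GT handling, etc.):

import torch

class ToyTrainer:
    def __init__(self, mega_batch_factor=2, device='cpu'):
        self.mega_batch_factor = mega_batch_factor
        self.device = device

    def feed_data(self, data, step, need_GT=True, perform_micro_batching=True):
        self.batch_factor = self.mega_batch_factor
        # Collapse to a single "micro-batch" when micro-batching is disabled (eval).
        batch_factor = self.batch_factor if perform_micro_batching else 1
        self.dstate = {}
        for k, v in data.items():
            if isinstance(v, torch.Tensor):
                self.dstate[k] = [t.to(self.device) for t in torch.chunk(v, chunks=batch_factor, dim=0)]

trainer = ToyTrainer()
data = {'hq': torch.randn(4, 3, 16, 16)}

trainer.feed_data(data, step=0)                                # training: len(trainer.dstate['hq']) == 2
trainer.feed_data(data, step=0, perform_micro_batching=False)  # eval: len(trainer.dstate['hq']) == 1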