forked from mrq/DL-Art-School

No batch factors for eval

This commit is contained in:
parent 82fc69abfa
commit 04d14b3acc
@@ -221,7 +221,7 @@ class Trainer:
         if opt_get(opt, ['eval', 'pure'], False) and self.current_step % opt['train']['val_freq'] == 0:
             metrics = []
             for val_data in tqdm(self.val_loader):
-                self.model.feed_data(val_data, self.current_step)
+                self.model.feed_data(val_data, self.current_step, perform_micro_batching=False)
                 metrics.append(self.model.test())
             reduced_metrics = {}
             for metric in metrics:
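For context, the eval gate above relies on the repo's opt_get helper. A minimal sketch of its assumed behavior, inferred from the call site rather than copied from DL-Art-School, would be:

# Hedged sketch: nested-dict lookup with a fallback default, matching how
# opt_get(opt, ['eval', 'pure'], False) is used above. The implementation
# details here are assumptions, not the repo's actual code.
def opt_get(opt, keys, default=None):
    v = opt
    for k in keys:
        if not isinstance(v, dict) or k not in v:
            return default
        v = v[k]
    return v

opt = {'eval': {'pure': True}, 'train': {'val_freq': 500}}
assert opt_get(opt, ['eval', 'pure'], False) is True      # pure-eval enabled
assert opt_get(opt, ['eval', 'missing'], False) is False  # absent key -> default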
@@ -157,7 +157,7 @@ class ExtensibleTrainer(BaseModel):
         # Setting this to false triggers SRGAN to call the models update_model() function on the first iteration.
         self.updated = True
 
-    def feed_data(self, data, step, need_GT=True):
+    def feed_data(self, data, step, need_GT=True, perform_micro_batching=True):
         self.env['step'] = step
         self.batch_factor = self.mega_batch_factor
         self.opt['checkpointing_enabled'] = self.checkpointing_cache
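Because the new keyword defaults to True, existing training call sites that pass only (data, step) keep their micro-batched behavior; only the eval loop opts out. A toy stand-in illustrating that contract (hypothetical body, not the trainer's real code):

# Toy sketch of the new signature's backward compatibility: callers that omit
# the flag get the old micro-batched path; eval disables it explicitly.
def feed_data(data, step, need_GT=True, perform_micro_batching=True):
    mega_batch_factor = 4  # illustrative value, assumed for the sketch
    return mega_batch_factor if perform_micro_batching else 1

assert feed_data([], 0) == 4                                # training path, unchanged
assert feed_data([], 0, perform_micro_batching=False) == 1  # new eval path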
@@ -174,10 +174,11 @@ class ExtensibleTrainer(BaseModel):
             o.zero_grad()
         torch.cuda.empty_cache()
 
+        batch_factor = self.batch_factor if perform_micro_batching else 1
         self.dstate = {}
         for k, v in data.items():
             if isinstance(v, torch.Tensor):
-                self.dstate[k] = [t.to(self.device) for t in torch.chunk(v, chunks=self.batch_factor, dim=0)]
+                self.dstate[k] = [t.to(self.device) for t in torch.chunk(v, chunks=batch_factor, dim=0)]
 
     def optimize_parameters(self, step):
         # Some models need to make parametric adjustments per-step. Do that here.
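The effect of the new batch_factor selection can be seen with torch.chunk directly: with chunks=1 the validation batch stays whole, which is the point of skipping micro-batching during eval. A minimal standalone sketch (illustrative shapes, not the trainer's data):

import torch

# With micro-batching, a batch is split along dim 0 into batch_factor chunks
# that are fed through the model one at a time; with batch_factor = 1 (the
# new eval path) the full batch is kept as a single chunk.
v = torch.randn(8, 16)

micro = torch.chunk(v, chunks=4, dim=0)   # training: 4 micro-batches of 2
whole = torch.chunk(v, chunks=1, dim=0)   # eval: one chunk of 8

assert len(micro) == 4 and micro[0].shape == (2, 16)
assert len(whole) == 1 and whole[0].shape == (8, 16)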