Fix evaluation when using multiple batch sizes

James Betker 2022-04-05 07:51:09 -06:00
parent 572d137589
commit 3d916e7687


@@ -37,6 +37,7 @@ class Trainer:
         self.val_compute_fea = opt_get(opt, ['eval', 'compute_fea'], False)
         self.current_step = 0
         self.total_training_data_encountered = 0
+        self.next_eval_step = 0

         #### loading resume state if exists
         if opt['path'].get('resume_state', None):
@@ -253,7 +254,9 @@ class Trainer:
             val_freq = opt['train']['val_freq'] * batch_size
         else:
             val_freq = int(opt['train']['val_freq_megasamples'] * 1000000)
-        if opt_get(opt, ['eval', 'pure'], False) and self.total_training_data_encountered % val_freq == 0:
+        if opt_get(opt, ['eval', 'pure'], False) and self.total_training_data_encountered > self.next_eval_step:
+            self.next_eval_step = self.total_training_data_encountered + val_freq
             metrics = []
             for val_data in tqdm(self.val_loader):
                 self.model.feed_data(val_data, self.current_step, perform_micro_batching=False)
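
For context on the bug being fixed: the old condition, self.total_training_data_encountered % val_freq == 0, only fires when the sample counter lands exactly on a multiple of val_freq, and a batch size that does not divide val_freq can step the counter over every multiple, so pure evaluation never triggers. The new condition fires as soon as the counter passes the next scheduled evaluation point and then pushes that point forward by val_freq. A minimal standalone sketch of both checks (names like batch_size, evals_old, evals_new are illustrative and not from the repo):

val_freq = 1000
batch_size = 48   # does not divide val_freq evenly

# Old check: fires only when the counter lands exactly on a multiple
# of val_freq, which never happens with this batch size.
total, evals_old = 0, 0
for _ in range(100):
    total += batch_size
    if total % val_freq == 0:
        evals_old += 1

# New check: fires whenever the counter has passed the next scheduled
# evaluation point, then reschedules that point val_freq samples ahead.
total, evals_new, next_eval_step = 0, 0, 0
for _ in range(100):
    total += batch_size
    if total > next_eval_step:
        next_eval_step = total + val_freq
        evals_new += 1

print(evals_old, evals_new)  # 0 5: the modulo check misses every evaluation

One consequence of the threshold scheme visible in the sketch: because next_eval_step is rescheduled relative to the current counter rather than the next exact multiple, evaluations drift slightly later over time, but they can no longer be skipped outright regardless of batch size.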