diff --git a/codes/train.py b/codes/train.py index aed20554..5ea0e865 100644 --- a/codes/train.py +++ b/codes/train.py @@ -37,6 +37,7 @@ class Trainer: self.val_compute_fea = opt_get(opt, ['eval', 'compute_fea'], False) self.current_step = 0 self.total_training_data_encountered = 0 + self.next_eval_step = 0 #### loading resume state if exists if opt['path'].get('resume_state', None): @@ -253,7 +254,9 @@ class Trainer: val_freq = opt['train']['val_freq'] * batch_size else: val_freq = int(opt['train']['val_freq_megasamples'] * 1000000) - if opt_get(opt, ['eval', 'pure'], False) and self.total_training_data_encountered % val_freq == 0: + + if opt_get(opt, ['eval', 'pure'], False) and self.total_training_data_encountered > self.next_eval_step: + self.next_eval_step = self.total_training_data_encountered + val_freq metrics = [] for val_data in tqdm(self.val_loader): self.model.feed_data(val_data, self.current_step, perform_micro_batching=False)