From e6387c7613589c19721e1a4e9009582b09445334 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Thu, 7 Apr 2022 11:29:57 -0600
Subject: [PATCH] Fix eval logic to not run immediately

---
 codes/train.py | 23 +++++++++++++----------
 1 file changed, 13 insertions(+), 10 deletions(-)

diff --git a/codes/train.py b/codes/train.py
index 5ea0e865..70578721 100644
--- a/codes/train.py
+++ b/codes/train.py
@@ -37,7 +37,6 @@ class Trainer:
         self.val_compute_fea = opt_get(opt, ['eval', 'compute_fea'], False)
         self.current_step = 0
         self.total_training_data_encountered = 0
-        self.next_eval_step = 0

         #### loading resume state if exists
         if opt['path'].get('resume_state', None):
@@ -169,6 +168,14 @@ class Trainer:
             self.total_training_data_encountered = self.current_step * opt['datasets']['train']['batch_size']
         opt['current_step'] = self.current_step

+        #### validation
+        if 'val_freq' in opt['train'].keys():
+            self.val_freq = opt['train']['val_freq'] * opt['datasets']['train']['batch_size']
+        else:
+            self.val_freq = int(opt['train']['val_freq_megasamples'] * 1000000)
+
+        self.next_eval_step = self.total_training_data_encountered + self.val_freq
+
     def do_step(self, train_data):
         if self._profile:
             print("Data fetch: %f" % (time() - _t))
@@ -249,14 +256,10 @@ class Trainer:
                 shutil.rmtree(alt_tblogger, ignore_errors=True)
                 shutil.copytree(self.tb_logger_path, alt_tblogger)

-        #### validation
-        if 'val_freq' in opt['train'].keys():
-            val_freq = opt['train']['val_freq'] * batch_size
-        else:
-            val_freq = int(opt['train']['val_freq_megasamples'] * 1000000)
-
-        if opt_get(opt, ['eval', 'pure'], False) and self.total_training_data_encountered > self.next_eval_step:
-            self.next_eval_step = self.total_training_data_encountered + val_freq
+        do_eval = self.total_training_data_encountered > self.next_eval_step
+        if do_eval:
+            self.next_eval_step = self.total_training_data_encountered + self.val_freq
+        if opt_get(opt, ['eval', 'pure'], False) and do_eval:
             metrics = []
             for val_data in tqdm(self.val_loader):
                 self.model.feed_data(val_data, self.current_step, perform_micro_batching=False)
@@ -278,7 +281,7 @@ class Trainer:
                     import wandb
                     wandb.log({f'eval_{k}': torch.stack(v).mean().item() for k,v in reduced_metrics.items()})

-        if len(self.evaluators) != 0 and self.current_step % opt['train']['val_freq'] == 0:
+        if len(self.evaluators) != 0 and do_eval:
             eval_dict = {}
             for eval in self.evaluators:
                 if eval.uses_all_ddp or self.rank <= 0:
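
Note (not part of the patch): a minimal, self-contained sketch of the sample-count-based
eval scheduling this change introduces. The option layout and the val_freq / next_eval_step
names follow the diff above; the EvalScheduler class, the compute_val_freq helper, and the
numbers in the example are illustrative assumptions, not code from the repository.

    def compute_val_freq(opt):
        # Interpret val_freq in optimizer steps (scaled by batch size), or fall back
        # to val_freq_megasamples expressed directly in samples.
        if 'val_freq' in opt['train']:
            return opt['train']['val_freq'] * opt['datasets']['train']['batch_size']
        return int(opt['train']['val_freq_megasamples'] * 1_000_000)

    class EvalScheduler:
        def __init__(self, opt, total_training_data_encountered=0):
            self.val_freq = compute_val_freq(opt)
            # Schedule the first eval one full interval past the (possibly resumed)
            # sample count, so eval no longer fires immediately on startup or resume.
            self.next_eval_step = total_training_data_encountered + self.val_freq

        def should_eval(self, total_training_data_encountered):
            # Same gate as the do_eval flag added in the patch: eval only once the
            # sample count has passed next_eval_step, then push the target forward.
            do_eval = total_training_data_encountered > self.next_eval_step
            if do_eval:
                self.next_eval_step = total_training_data_encountered + self.val_freq
            return do_eval

Example: with val_freq=500 and batch_size=32 the interval is 16,000 samples, so
should_eval(0) is False at startup and the first eval fires only after more than
16,000 samples have been seen, mirroring the do_eval gate in the diff.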