From 3fdf2a63aaf901f16763fa632269b823915199f4 Mon Sep 17 00:00:00 2001 From: mrq Date: Sat, 11 Mar 2023 01:18:25 +0000 Subject: [PATCH] fixes --- codes/train.py | 29 ++++++++--------------------- 1 file changed, 8 insertions(+), 21 deletions(-) diff --git a/codes/train.py b/codes/train.py index 7e1db180..70c9d7e1 100644 --- a/codes/train.py +++ b/codes/train.py @@ -309,14 +309,6 @@ class Trainer: state['dataset_debugger_state'] = self.dataset_debugger.get_state() self.model.save_training_state(state) self.logger.info('Saving models and training states.') - """ - if 'alt_path' in opt['path'].keys(): - import shutil - print("Synchronizing tb_logger to alt_path..") - alt_tblogger = os.path.join(opt['path']['alt_path'], "tb_logger") - shutil.rmtree(alt_tblogger, ignore_errors=True) - shutil.copytree(self.tb_logger_path, alt_tblogger) - """ do_eval = self.total_training_data_encountered > self.next_eval_step if do_eval: @@ -331,14 +323,6 @@ class Trainer: eval_dict.update(eval.perform_eval()) if self.rank <= 0: print("Evaluator results: ", eval_dict) - """ - for ek, ev in eval_dict.items(): - self.tb_logger.add_scalar(ek, ev, self.current_step) - - if opt['wandb']: - import wandb - wandb.log(eval_dict) - """ # Should not be necessary, but make absolutely sure that there is no grad leakage from validation runs. for net in self.model.networks.values(): @@ -351,17 +335,20 @@ class Trainer: self.logger.info('Beginning validation.') metrics = [] - tq_ldr = tqdm(self.train_loader, desc="Validating") if self.use_tqdm else self.train_loader + tq_ldr = tqdm(self.val_loader, desc="Validating") if self.use_tqdm else self.val_loader for val_data in tq_ldr: self.model.feed_data(val_data, self.current_step, perform_micro_batching=False) metric = self.model.test() metrics.append(metric) - if self.rank <= 0: + if self.rank <= 0 and self.use_tqdm: logs = process_metrics( metrics ) - if self.use_tqdm: - tq_ldr.set_postfix( logs, refresh=True ) - self.logger.info(f'Validation Metrics: {json.dumps(logs)}') + tq_ldr.set_postfix( logs, refresh=True ) + + if self.rank <= 0: + logs = process_metrics( metrics ) + logs['it'] = self.current_step + self.logger.info(f'Validation Metrics: {json.dumps(logs)}') def do_training(self): if self.rank <= 0: