diff --git a/codes/train.py b/codes/train.py index f0727a98..242cf515 100644 --- a/codes/train.py +++ b/codes/train.py @@ -257,10 +257,11 @@ class Trainer: self.tb_logger.add_scalar('val_psnr', avg_psnr, self.current_step) self.tb_logger.add_scalar('val_fea', avg_fea_loss, self.current_step) - if len(self.evaluators) != 0 and self.current_step % opt['train']['val_freq'] == 0 and self.rank <= 0: + if len(self.evaluators) != 0 and self.current_step % opt['train']['val_freq'] == 0: eval_dict = {} for eval in self.evaluators: - eval_dict.update(eval.perform_eval()) + if eval.uses_all_ddp() or self.rank <= 0: + eval_dict.update(eval.perform_eval()) if self.rank <= 0: print("Evaluator results: ", eval_dict) for ek, ev in eval_dict.items(): diff --git a/codes/trainer/eval/evaluator.py b/codes/trainer/eval/evaluator.py index d33705f1..519403c3 100644 --- a/codes/trainer/eval/evaluator.py +++ b/codes/trainer/eval/evaluator.py @@ -11,6 +11,7 @@ class Evaluator: self.model = model.module if hasattr(model, 'module') else model self.opt = opt_eval self.env = env + self.uses_all_ddp = opt_get(opt_eval, ['uses_all_ddp'], True) def perform_eval(self): return {}