diff --git a/codes/train.py b/codes/train.py index 16b5b44c..6c55bb10 100644 --- a/codes/train.py +++ b/codes/train.py @@ -270,6 +270,7 @@ class Trainer: for ek, ev in eval_dict.items(): self.tb_logger.add_scalar(ek, ev, self.current_step) if opt['wandb']: + import wandb wandb.log(eval_dict)