diff --git a/codes/requirements.txt b/codes/requirements.txt index 5a8d5955..e6a0e950 100644 --- a/codes/requirements.txt +++ b/codes/requirements.txt @@ -18,4 +18,5 @@ vector_quantize_pytorch orjson einops gsa-pytorch -lambda-networks \ No newline at end of file +lambda-networks +pytorch_ssim \ No newline at end of file diff --git a/codes/train.py b/codes/train.py index 1b580a8f..16b5b44c 100644 --- a/codes/train.py +++ b/codes/train.py @@ -269,6 +269,9 @@ class Trainer: print("Evaluator results: ", eval_dict) for ek, ev in eval_dict.items(): self.tb_logger.add_scalar(ek, ev, self.current_step) + if opt['wandb']: + wandb.log(eval_dict) + def do_training(self): self.logger.info('Start training from epoch: {:d}, iter: {:d}'.format(self.start_epoch, self.current_step))