diff --git a/codes/train.py b/codes/train.py index 5447cc1c..5acf30fe 100644 --- a/codes/train.py +++ b/codes/train.py @@ -238,6 +238,12 @@ class Trainer: self.tb_logger.add_scalar(k, v, self.current_step) if opt['wandb'] and self.rank <= 0: import wandb + wandb_logs = {} + for k, v in logs.items(): + if 'histogram' in k: + wandb_logs[k] = wandb.Histogram(v) + else: + wandb_logs[k] = v if opt_get(opt, ['wandb_progress_use_raw_steps'], False): wandb.log(logs, step=self.current_step) else: