diff --git a/codes/train.py b/codes/train.py index 9a2e8a80..a3c4edd2 100644 --- a/codes/train.py +++ b/codes/train.py @@ -210,7 +210,7 @@ class Trainer: self.tb_logger.add_scalar(k, v, self.current_step) if opt['wandb'] and self.rank <= 0: import wandb - wandb.log(logs, step=self.current_step) + wandb.log(logs, step=int(self.current_step * opt_get(opt, ['wandb_step_factor'], 1))) self.logger.info(message) #### save models and training states