diff --git a/codes/train.py b/codes/train.py index 35efdd85..a9aa6ca9 100644 --- a/codes/train.py +++ b/codes/train.py @@ -81,7 +81,7 @@ class Trainer: self.opt = opt #### wandb init - if opt['wandb']: + if opt['wandb'] and self.rank <= 0: import wandb os.makedirs(os.path.join(opt['path']['log'], 'wandb'), exist_ok=True) wandb.init(project=opt['name'], dir=opt['path']['log']) @@ -193,7 +193,7 @@ class Trainer: # tensorboard logger if opt['use_tb_logger'] and 'debug' not in opt['name']: self.tb_logger.add_scalar(k, v, self.current_step) - if opt['wandb']: + if opt['wandb'] and self.rank <= 0: import wandb wandb.log(logs) self.logger.info(message) diff --git a/codes/trainer/ExtensibleTrainer.py b/codes/trainer/ExtensibleTrainer.py index acb04d53..e5e25ae6 100644 --- a/codes/trainer/ExtensibleTrainer.py +++ b/codes/trainer/ExtensibleTrainer.py @@ -70,7 +70,7 @@ class ExtensibleTrainer(BaseModel): if not net['trainable']: new_net.eval() - if net['wandb_debug']: + if net['wandb_debug'] and self.rank <= 0: import wandb wandb.watch(new_net, log='all', log_freq=3)