From 2ad2b5643880fefdb2ae8b7728aca8f873f62789 Mon Sep 17 00:00:00 2001 From: James Betker Date: Sun, 6 Jun 2021 16:52:07 -0600 Subject: [PATCH] Don't do wandb except on rank 0 --- codes/train.py | 4 ++-- codes/trainer/ExtensibleTrainer.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/codes/train.py b/codes/train.py index 35efdd85..a9aa6ca9 100644 --- a/codes/train.py +++ b/codes/train.py @@ -81,7 +81,7 @@ class Trainer: self.opt = opt #### wandb init - if opt['wandb']: + if opt['wandb'] and self.rank <= 0: import wandb os.makedirs(os.path.join(opt['path']['log'], 'wandb'), exist_ok=True) wandb.init(project=opt['name'], dir=opt['path']['log']) @@ -193,7 +193,7 @@ class Trainer: # tensorboard logger if opt['use_tb_logger'] and 'debug' not in opt['name']: self.tb_logger.add_scalar(k, v, self.current_step) - if opt['wandb']: + if opt['wandb'] and self.rank <= 0: import wandb wandb.log(logs) self.logger.info(message) diff --git a/codes/trainer/ExtensibleTrainer.py b/codes/trainer/ExtensibleTrainer.py index acb04d53..e5e25ae6 100644 --- a/codes/trainer/ExtensibleTrainer.py +++ b/codes/trainer/ExtensibleTrainer.py @@ -70,7 +70,7 @@ class ExtensibleTrainer(BaseModel): if not net['trainable']: new_net.eval() - if net['wandb_debug']: + if net['wandb_debug'] and self.rank <= 0: import wandb wandb.watch(new_net, log='all', log_freq=3)