Don't do wandb except on rank 0
This commit is contained in:
parent
7c5478bc2c
commit
2ad2b56438
|
@ -81,7 +81,7 @@ class Trainer:
|
||||||
self.opt = opt
|
self.opt = opt
|
||||||
|
|
||||||
#### wandb init
|
#### wandb init
|
||||||
if opt['wandb']:
|
if opt['wandb'] and self.rank <= 0:
|
||||||
import wandb
|
import wandb
|
||||||
os.makedirs(os.path.join(opt['path']['log'], 'wandb'), exist_ok=True)
|
os.makedirs(os.path.join(opt['path']['log'], 'wandb'), exist_ok=True)
|
||||||
wandb.init(project=opt['name'], dir=opt['path']['log'])
|
wandb.init(project=opt['name'], dir=opt['path']['log'])
|
||||||
|
@ -193,7 +193,7 @@ class Trainer:
|
||||||
# tensorboard logger
|
# tensorboard logger
|
||||||
if opt['use_tb_logger'] and 'debug' not in opt['name']:
|
if opt['use_tb_logger'] and 'debug' not in opt['name']:
|
||||||
self.tb_logger.add_scalar(k, v, self.current_step)
|
self.tb_logger.add_scalar(k, v, self.current_step)
|
||||||
if opt['wandb']:
|
if opt['wandb'] and self.rank <= 0:
|
||||||
import wandb
|
import wandb
|
||||||
wandb.log(logs)
|
wandb.log(logs)
|
||||||
self.logger.info(message)
|
self.logger.info(message)
|
||||||
|
|
|
@ -70,7 +70,7 @@ class ExtensibleTrainer(BaseModel):
|
||||||
|
|
||||||
if not net['trainable']:
|
if not net['trainable']:
|
||||||
new_net.eval()
|
new_net.eval()
|
||||||
if net['wandb_debug']:
|
if net['wandb_debug'] and self.rank <= 0:
|
||||||
import wandb
|
import wandb
|
||||||
wandb.watch(new_net, log='all', log_freq=3)
|
wandb.watch(new_net, log='all', log_freq=3)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user