From 922b1d76df7330ad6f0957c4bc42519d1c181379 Mon Sep 17 00:00:00 2001 From: James Betker Date: Sat, 3 Oct 2020 11:09:09 -0600 Subject: [PATCH] Don't record visuals when not on rank 0 --- codes/models/ExtensibleTrainer.py | 2 +- codes/models/base_model.py | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/codes/models/ExtensibleTrainer.py b/codes/models/ExtensibleTrainer.py index 395f88bc..9d4339f8 100644 --- a/codes/models/ExtensibleTrainer.py +++ b/codes/models/ExtensibleTrainer.py @@ -226,7 +226,7 @@ class ExtensibleTrainer(BaseModel): [e.after_optimize(state) for e in self.experiments] # Record visual outputs for usage in debugging and testing. - if 'visuals' in self.opt['logger'].keys(): + if 'visuals' in self.opt['logger'].keys() and self.rank <= 0: sample_save_path = os.path.join(self.opt['path']['models'], "..", "visual_dbg") for v in self.opt['logger']['visuals']: if v not in state.keys(): diff --git a/codes/models/base_model.py b/codes/models/base_model.py index 37af08cf..0b092520 100644 --- a/codes/models/base_model.py +++ b/codes/models/base_model.py @@ -10,6 +10,10 @@ from apex import amp class BaseModel(): def __init__(self, opt): self.opt = opt + if opt['dist']: + self.rank = torch.distributed.get_rank() + else: + self.rank = -1 # non dist training self.device = torch.device('cuda' if opt['gpu_ids'] is not None else 'cpu') self.amp_level = 'O0' if opt['amp_opt_level'] is None else opt['amp_opt_level'] self.is_train = opt['is_train']