Don't record visuals when not on rank 0

This commit is contained in:
James Betker 2020-10-03 11:09:09 -06:00
parent 8197fd646f
commit 922b1d76df
2 changed files with 5 additions and 1 deletions

View File

@ -226,7 +226,7 @@ class ExtensibleTrainer(BaseModel):
[e.after_optimize(state) for e in self.experiments] [e.after_optimize(state) for e in self.experiments]
# Record visual outputs for usage in debugging and testing. # Record visual outputs for usage in debugging and testing.
if 'visuals' in self.opt['logger'].keys(): if 'visuals' in self.opt['logger'].keys() and self.rank <= 0:
sample_save_path = os.path.join(self.opt['path']['models'], "..", "visual_dbg") sample_save_path = os.path.join(self.opt['path']['models'], "..", "visual_dbg")
for v in self.opt['logger']['visuals']: for v in self.opt['logger']['visuals']:
if v not in state.keys(): if v not in state.keys():

View File

@ -10,6 +10,10 @@ from apex import amp
class BaseModel(): class BaseModel():
def __init__(self, opt): def __init__(self, opt):
self.opt = opt self.opt = opt
if opt['dist']:
self.rank = torch.distributed.get_rank()
else:
self.rank = -1 # non dist training
self.device = torch.device('cuda' if opt['gpu_ids'] is not None else 'cpu') self.device = torch.device('cuda' if opt['gpu_ids'] is not None else 'cpu')
self.amp_level = 'O0' if opt['amp_opt_level'] is None else opt['amp_opt_level'] self.amp_level = 'O0' if opt['amp_opt_level'] is None else opt['amp_opt_level']
self.is_train = opt['is_train'] self.is_train = opt['is_train']