Don't record visuals when not on rank 0
This commit is contained in:
parent
8197fd646f
commit
922b1d76df
|
@ -226,7 +226,7 @@ class ExtensibleTrainer(BaseModel):
|
||||||
[e.after_optimize(state) for e in self.experiments]
|
[e.after_optimize(state) for e in self.experiments]
|
||||||
|
|
||||||
# Record visual outputs for usage in debugging and testing.
|
# Record visual outputs for usage in debugging and testing.
|
||||||
if 'visuals' in self.opt['logger'].keys():
|
if 'visuals' in self.opt['logger'].keys() and self.rank <= 0:
|
||||||
sample_save_path = os.path.join(self.opt['path']['models'], "..", "visual_dbg")
|
sample_save_path = os.path.join(self.opt['path']['models'], "..", "visual_dbg")
|
||||||
for v in self.opt['logger']['visuals']:
|
for v in self.opt['logger']['visuals']:
|
||||||
if v not in state.keys():
|
if v not in state.keys():
|
||||||
|
|
|
@ -10,6 +10,10 @@ from apex import amp
|
||||||
class BaseModel():
|
class BaseModel():
|
||||||
def __init__(self, opt):
|
def __init__(self, opt):
|
||||||
self.opt = opt
|
self.opt = opt
|
||||||
|
if opt['dist']:
|
||||||
|
self.rank = torch.distributed.get_rank()
|
||||||
|
else:
|
||||||
|
self.rank = -1 # non dist training
|
||||||
self.device = torch.device('cuda' if opt['gpu_ids'] is not None else 'cpu')
|
self.device = torch.device('cuda' if opt['gpu_ids'] is not None else 'cpu')
|
||||||
self.amp_level = 'O0' if opt['amp_opt_level'] is None else opt['amp_opt_level']
|
self.amp_level = 'O0' if opt['amp_opt_level'] is None else opt['amp_opt_level']
|
||||||
self.is_train = opt['is_train']
|
self.is_train = opt['is_train']
|
||||||
|
|
Loading…
Reference in New Issue
Block a user