Don't record visuals when not on rank 0
parent 8197fd646f
commit 922b1d76df
@@ -226,7 +226,7 @@ class ExtensibleTrainer(BaseModel):
         [e.after_optimize(state) for e in self.experiments]

         # Record visual outputs for usage in debugging and testing.
-        if 'visuals' in self.opt['logger'].keys():
+        if 'visuals' in self.opt['logger'].keys() and self.rank <= 0:
             sample_save_path = os.path.join(self.opt['path']['models'], "..", "visual_dbg")
             for v in self.opt['logger']['visuals']:
                 if v not in state.keys():
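Note: the change above is the usual DistributedDataParallel guard: only the rank-0 process (or a non-distributed run) writes debug visuals to disk, so the other workers neither duplicate the images nor race on the same file paths. A minimal sketch of that guard, assuming the visuals in `state` are image tensors and using torchvision only for the actual file write (the `save_visuals` helper and its arguments are illustrative, not part of this repository):

import os
import torch
import torchvision.utils as vutils

def save_visuals(state, visual_keys, save_path, step, rank):
    # Only rank 0 (or rank == -1, i.e. non-distributed training) touches the filesystem.
    if rank > 0:
        return
    os.makedirs(save_path, exist_ok=True)
    for k in visual_keys:
        if k not in state:
            continue  # Some visuals are not produced on every step; skip them.
        img = state[k]
        if isinstance(img, torch.Tensor):
            vutils.save_image(img.float().cpu(), os.path.join(save_path, "%s_%i.png" % (k, step)))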
@@ -10,6 +10,10 @@ from apex import amp
 class BaseModel():
     def __init__(self, opt):
         self.opt = opt
+        if opt['dist']:
+            self.rank = torch.distributed.get_rank()
+        else:
+            self.rank = -1  # non dist training
         self.device = torch.device('cuda' if opt['gpu_ids'] is not None else 'cpu')
         self.amp_level = 'O0' if opt['amp_opt_level'] is None else opt['amp_opt_level']
         self.is_train = opt['is_train']
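The rank added to BaseModel above follows a common convention: torch.distributed.get_rank() when distributed training is enabled, and -1 for single-process runs, so `self.rank <= 0` is true exactly for the one process that should handle logging and I/O. A small self-contained sketch of that convention (the `get_logging_rank` helper is a made-up name for illustration):

import torch.distributed as dist

def get_logging_rank(is_dist: bool) -> int:
    # Real rank when torch.distributed is initialized, -1 for non-distributed training.
    if is_dist and dist.is_available() and dist.is_initialized():
        return dist.get_rank()
    return -1

# Usage: rank <= 0 holds only on rank 0 or in a non-distributed run,
# which is exactly the condition the visual-recording guard checks.
rank = get_logging_rank(is_dist=False)
record_visuals = rank <= 0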