Add capacity for models to self-report visuals
This commit is contained in:
parent
dca5cddb3b
commit
3c6e600e48
|
@ -230,25 +230,30 @@ class ExtensibleTrainer(BaseModel):
|
|||
[e.after_optimize(state) for e in self.experiments]
|
||||
|
||||
# Record visual outputs for usage in debugging and testing.
|
||||
if 'visuals' in self.opt['logger'].keys() and self.rank <= 0:
|
||||
if 'visuals' in self.opt['logger'].keys() and self.rank <= 0 and step % self.opt['logger']['visual_debug_rate'] == 0:
|
||||
sample_save_path = os.path.join(self.opt['path']['models'], "..", "visual_dbg")
|
||||
for v in self.opt['logger']['visuals']:
|
||||
if v not in state.keys():
|
||||
continue # This can happen for several reasons (ex: 'after' defs), just ignore it.
|
||||
if step % self.opt['logger']['visual_debug_rate'] == 0:
|
||||
for i, dbgv in enumerate(state[v]):
|
||||
if 'recurrent_visual_indices' in self.opt['logger'].keys():
|
||||
for rvi in self.opt['logger']['recurrent_visual_indices']:
|
||||
rdbgv = dbgv[:, rvi]
|
||||
if rdbgv.shape[1] > 3:
|
||||
rdbgv = rdbgv[:, :3, :, :]
|
||||
os.makedirs(os.path.join(sample_save_path, v), exist_ok=True)
|
||||
utils.save_image(rdbgv, os.path.join(sample_save_path, v, "%05i_%02i_%02i.png" % (step, rvi, i)))
|
||||
else:
|
||||
if dbgv.shape[1] > 3:
|
||||
dbgv = dbgv[:,:3,:,:]
|
||||
for i, dbgv in enumerate(state[v]):
|
||||
if 'recurrent_visual_indices' in self.opt['logger'].keys():
|
||||
for rvi in self.opt['logger']['recurrent_visual_indices']:
|
||||
rdbgv = dbgv[:, rvi]
|
||||
if rdbgv.shape[1] > 3:
|
||||
rdbgv = rdbgv[:, :3, :, :]
|
||||
os.makedirs(os.path.join(sample_save_path, v), exist_ok=True)
|
||||
utils.save_image(dbgv, os.path.join(sample_save_path, v, "%05i_%02i.png" % (step, i)))
|
||||
utils.save_image(rdbgv, os.path.join(sample_save_path, v, "%05i_%02i_%02i.png" % (step, rvi, i)))
|
||||
else:
|
||||
if dbgv.shape[1] > 3:
|
||||
dbgv = dbgv[:,:3,:,:]
|
||||
os.makedirs(os.path.join(sample_save_path, v), exist_ok=True)
|
||||
utils.save_image(dbgv, os.path.join(sample_save_path, v, "%05i_%02i.png" % (step, i)))
|
||||
# Some models have their own specific visual debug routines.
|
||||
for net_name, net in self.networks.items():
|
||||
if hasattr(net.module, "visual_dbg"):
|
||||
model_vdbg_dir = os.path.join(sample_save_path, net_name)
|
||||
os.makedirs(model_vdbg_dir, exist_ok=True)
|
||||
net.module.visual_dbg(step, model_vdbg_dir)
|
||||
|
||||
def compute_fea_loss(self, real, fake):
|
||||
with torch.no_grad():
|
||||
|
|
Loading…
Reference in New Issue
Block a user