Add capacity for models to self-report visuals

This commit is contained in:
James Betker 2020-10-21 11:08:03 -06:00
parent dca5cddb3b
commit 3c6e600e48

View File

@ -230,25 +230,30 @@ class ExtensibleTrainer(BaseModel):
[e.after_optimize(state) for e in self.experiments]
# Record visual outputs for usage in debugging and testing.
if 'visuals' in self.opt['logger'].keys() and self.rank <= 0:
if 'visuals' in self.opt['logger'].keys() and self.rank <= 0 and step % self.opt['logger']['visual_debug_rate'] == 0:
sample_save_path = os.path.join(self.opt['path']['models'], "..", "visual_dbg")
for v in self.opt['logger']['visuals']:
if v not in state.keys():
continue # This can happen for several reasons (ex: 'after' defs), just ignore it.
if step % self.opt['logger']['visual_debug_rate'] == 0:
for i, dbgv in enumerate(state[v]):
if 'recurrent_visual_indices' in self.opt['logger'].keys():
for rvi in self.opt['logger']['recurrent_visual_indices']:
rdbgv = dbgv[:, rvi]
if rdbgv.shape[1] > 3:
rdbgv = rdbgv[:, :3, :, :]
os.makedirs(os.path.join(sample_save_path, v), exist_ok=True)
utils.save_image(rdbgv, os.path.join(sample_save_path, v, "%05i_%02i_%02i.png" % (step, rvi, i)))
else:
if dbgv.shape[1] > 3:
dbgv = dbgv[:,:3,:,:]
for i, dbgv in enumerate(state[v]):
if 'recurrent_visual_indices' in self.opt['logger'].keys():
for rvi in self.opt['logger']['recurrent_visual_indices']:
rdbgv = dbgv[:, rvi]
if rdbgv.shape[1] > 3:
rdbgv = rdbgv[:, :3, :, :]
os.makedirs(os.path.join(sample_save_path, v), exist_ok=True)
utils.save_image(dbgv, os.path.join(sample_save_path, v, "%05i_%02i.png" % (step, i)))
utils.save_image(rdbgv, os.path.join(sample_save_path, v, "%05i_%02i_%02i.png" % (step, rvi, i)))
else:
if dbgv.shape[1] > 3:
dbgv = dbgv[:,:3,:,:]
os.makedirs(os.path.join(sample_save_path, v), exist_ok=True)
utils.save_image(dbgv, os.path.join(sample_save_path, v, "%05i_%02i.png" % (step, i)))
# Some models have their own specific visual debug routines.
for net_name, net in self.networks.items():
if hasattr(net.module, "visual_dbg"):
model_vdbg_dir = os.path.join(sample_save_path, net_name)
os.makedirs(model_vdbg_dir, exist_ok=True)
net.module.visual_dbg(step, model_vdbg_dir)
def compute_fea_loss(self, real, fake):
with torch.no_grad():