Add capacity for models to self-report visuals

This commit is contained in:
James Betker 2020-10-21 11:08:03 -06:00
parent dca5cddb3b
commit 3c6e600e48

View File

@ -230,25 +230,30 @@ class ExtensibleTrainer(BaseModel):
[e.after_optimize(state) for e in self.experiments] [e.after_optimize(state) for e in self.experiments]
# Record visual outputs for usage in debugging and testing. # Record visual outputs for usage in debugging and testing.
if 'visuals' in self.opt['logger'].keys() and self.rank <= 0: if 'visuals' in self.opt['logger'].keys() and self.rank <= 0 and step % self.opt['logger']['visual_debug_rate'] == 0:
sample_save_path = os.path.join(self.opt['path']['models'], "..", "visual_dbg") sample_save_path = os.path.join(self.opt['path']['models'], "..", "visual_dbg")
for v in self.opt['logger']['visuals']: for v in self.opt['logger']['visuals']:
if v not in state.keys(): if v not in state.keys():
continue # This can happen for several reasons (ex: 'after' defs), just ignore it. continue # This can happen for several reasons (ex: 'after' defs), just ignore it.
if step % self.opt['logger']['visual_debug_rate'] == 0: for i, dbgv in enumerate(state[v]):
for i, dbgv in enumerate(state[v]): if 'recurrent_visual_indices' in self.opt['logger'].keys():
if 'recurrent_visual_indices' in self.opt['logger'].keys(): for rvi in self.opt['logger']['recurrent_visual_indices']:
for rvi in self.opt['logger']['recurrent_visual_indices']: rdbgv = dbgv[:, rvi]
rdbgv = dbgv[:, rvi] if rdbgv.shape[1] > 3:
if rdbgv.shape[1] > 3: rdbgv = rdbgv[:, :3, :, :]
rdbgv = rdbgv[:, :3, :, :]
os.makedirs(os.path.join(sample_save_path, v), exist_ok=True)
utils.save_image(rdbgv, os.path.join(sample_save_path, v, "%05i_%02i_%02i.png" % (step, rvi, i)))
else:
if dbgv.shape[1] > 3:
dbgv = dbgv[:,:3,:,:]
os.makedirs(os.path.join(sample_save_path, v), exist_ok=True) os.makedirs(os.path.join(sample_save_path, v), exist_ok=True)
utils.save_image(dbgv, os.path.join(sample_save_path, v, "%05i_%02i.png" % (step, i))) utils.save_image(rdbgv, os.path.join(sample_save_path, v, "%05i_%02i_%02i.png" % (step, rvi, i)))
else:
if dbgv.shape[1] > 3:
dbgv = dbgv[:,:3,:,:]
os.makedirs(os.path.join(sample_save_path, v), exist_ok=True)
utils.save_image(dbgv, os.path.join(sample_save_path, v, "%05i_%02i.png" % (step, i)))
# Some models have their own specific visual debug routines.
for net_name, net in self.networks.items():
if hasattr(net.module, "visual_dbg"):
model_vdbg_dir = os.path.join(sample_save_path, net_name)
os.makedirs(model_vdbg_dir, exist_ok=True)
net.module.visual_dbg(step, model_vdbg_dir)
def compute_fea_loss(self, real, fake): def compute_fea_loss(self, real, fake):
with torch.no_grad(): with torch.no_grad():