Add capacity for models to self-report visuals
This commit is contained in:
parent
dca5cddb3b
commit
3c6e600e48
|
@ -230,25 +230,30 @@ class ExtensibleTrainer(BaseModel):
|
||||||
[e.after_optimize(state) for e in self.experiments]
|
[e.after_optimize(state) for e in self.experiments]
|
||||||
|
|
||||||
# Record visual outputs for usage in debugging and testing.
|
# Record visual outputs for usage in debugging and testing.
|
||||||
if 'visuals' in self.opt['logger'].keys() and self.rank <= 0:
|
if 'visuals' in self.opt['logger'].keys() and self.rank <= 0 and step % self.opt['logger']['visual_debug_rate'] == 0:
|
||||||
sample_save_path = os.path.join(self.opt['path']['models'], "..", "visual_dbg")
|
sample_save_path = os.path.join(self.opt['path']['models'], "..", "visual_dbg")
|
||||||
for v in self.opt['logger']['visuals']:
|
for v in self.opt['logger']['visuals']:
|
||||||
if v not in state.keys():
|
if v not in state.keys():
|
||||||
continue # This can happen for several reasons (ex: 'after' defs), just ignore it.
|
continue # This can happen for several reasons (ex: 'after' defs), just ignore it.
|
||||||
if step % self.opt['logger']['visual_debug_rate'] == 0:
|
for i, dbgv in enumerate(state[v]):
|
||||||
for i, dbgv in enumerate(state[v]):
|
if 'recurrent_visual_indices' in self.opt['logger'].keys():
|
||||||
if 'recurrent_visual_indices' in self.opt['logger'].keys():
|
for rvi in self.opt['logger']['recurrent_visual_indices']:
|
||||||
for rvi in self.opt['logger']['recurrent_visual_indices']:
|
rdbgv = dbgv[:, rvi]
|
||||||
rdbgv = dbgv[:, rvi]
|
if rdbgv.shape[1] > 3:
|
||||||
if rdbgv.shape[1] > 3:
|
rdbgv = rdbgv[:, :3, :, :]
|
||||||
rdbgv = rdbgv[:, :3, :, :]
|
|
||||||
os.makedirs(os.path.join(sample_save_path, v), exist_ok=True)
|
|
||||||
utils.save_image(rdbgv, os.path.join(sample_save_path, v, "%05i_%02i_%02i.png" % (step, rvi, i)))
|
|
||||||
else:
|
|
||||||
if dbgv.shape[1] > 3:
|
|
||||||
dbgv = dbgv[:,:3,:,:]
|
|
||||||
os.makedirs(os.path.join(sample_save_path, v), exist_ok=True)
|
os.makedirs(os.path.join(sample_save_path, v), exist_ok=True)
|
||||||
utils.save_image(dbgv, os.path.join(sample_save_path, v, "%05i_%02i.png" % (step, i)))
|
utils.save_image(rdbgv, os.path.join(sample_save_path, v, "%05i_%02i_%02i.png" % (step, rvi, i)))
|
||||||
|
else:
|
||||||
|
if dbgv.shape[1] > 3:
|
||||||
|
dbgv = dbgv[:,:3,:,:]
|
||||||
|
os.makedirs(os.path.join(sample_save_path, v), exist_ok=True)
|
||||||
|
utils.save_image(dbgv, os.path.join(sample_save_path, v, "%05i_%02i.png" % (step, i)))
|
||||||
|
# Some models have their own specific visual debug routines.
|
||||||
|
for net_name, net in self.networks.items():
|
||||||
|
if hasattr(net.module, "visual_dbg"):
|
||||||
|
model_vdbg_dir = os.path.join(sample_save_path, net_name)
|
||||||
|
os.makedirs(model_vdbg_dir, exist_ok=True)
|
||||||
|
net.module.visual_dbg(step, model_vdbg_dir)
|
||||||
|
|
||||||
def compute_fea_loss(self, real, fake):
|
def compute_fea_loss(self, real, fake):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
|
|
Loading…
Reference in New Issue
Block a user