diff --git a/codes/models/ExtensibleTrainer.py b/codes/models/ExtensibleTrainer.py index 8f4e2e9e..32b627d3 100644 --- a/codes/models/ExtensibleTrainer.py +++ b/codes/models/ExtensibleTrainer.py @@ -230,25 +230,30 @@ class ExtensibleTrainer(BaseModel): [e.after_optimize(state) for e in self.experiments] # Record visual outputs for usage in debugging and testing. - if 'visuals' in self.opt['logger'].keys() and self.rank <= 0: + if 'visuals' in self.opt['logger'].keys() and self.rank <= 0 and step % self.opt['logger']['visual_debug_rate'] == 0: sample_save_path = os.path.join(self.opt['path']['models'], "..", "visual_dbg") for v in self.opt['logger']['visuals']: if v not in state.keys(): continue # This can happen for several reasons (ex: 'after' defs), just ignore it. - if step % self.opt['logger']['visual_debug_rate'] == 0: - for i, dbgv in enumerate(state[v]): - if 'recurrent_visual_indices' in self.opt['logger'].keys(): - for rvi in self.opt['logger']['recurrent_visual_indices']: - rdbgv = dbgv[:, rvi] - if rdbgv.shape[1] > 3: - rdbgv = rdbgv[:, :3, :, :] - os.makedirs(os.path.join(sample_save_path, v), exist_ok=True) - utils.save_image(rdbgv, os.path.join(sample_save_path, v, "%05i_%02i_%02i.png" % (step, rvi, i))) - else: - if dbgv.shape[1] > 3: - dbgv = dbgv[:,:3,:,:] + for i, dbgv in enumerate(state[v]): + if 'recurrent_visual_indices' in self.opt['logger'].keys(): + for rvi in self.opt['logger']['recurrent_visual_indices']: + rdbgv = dbgv[:, rvi] + if rdbgv.shape[1] > 3: + rdbgv = rdbgv[:, :3, :, :] os.makedirs(os.path.join(sample_save_path, v), exist_ok=True) - utils.save_image(dbgv, os.path.join(sample_save_path, v, "%05i_%02i.png" % (step, i))) + utils.save_image(rdbgv, os.path.join(sample_save_path, v, "%05i_%02i_%02i.png" % (step, rvi, i))) + else: + if dbgv.shape[1] > 3: + dbgv = dbgv[:,:3,:,:] + os.makedirs(os.path.join(sample_save_path, v), exist_ok=True) + utils.save_image(dbgv, os.path.join(sample_save_path, v, "%05i_%02i.png" % (step, i))) + # Some models have their own specific visual debug routines. + for net_name, net in self.networks.items(): + if hasattr(net.module, "visual_dbg"): + model_vdbg_dir = os.path.join(sample_save_path, net_name) + os.makedirs(model_vdbg_dir, exist_ok=True) + net.module.visual_dbg(step, model_vdbg_dir) def compute_fea_loss(self, real, fake): with torch.no_grad():