diff --git a/codes/models/steps/tecogan_losses.py b/codes/models/steps/tecogan_losses.py index 46e5335a..5fe1a8ec 100644 --- a/codes/models/steps/tecogan_losses.py +++ b/codes/models/steps/tecogan_losses.py @@ -96,7 +96,7 @@ class RecurrentImageGeneratorSequenceInjector(Injector): recurrent_input = self.resample(reduced_recurrent.float(), flowfield.float()) input[self.recurrent_index] = recurrent_input if self.env['step'] % 50 == 0: - self.produce_teco_visual_debugs(input[self.input_lq_index], debug_index) + self.produce_teco_visual_debugs(input[self.input_lq_index], input[self.recurrent_index], debug_index) debug_index += 1 gen_out = gen(*input) if isinstance(gen_out, torch.Tensor): @@ -117,7 +117,7 @@ class RecurrentImageGeneratorSequenceInjector(Injector): input[self.recurrent_index ] = recurrent_input if self.env['step'] % 50 == 0: - self.produce_teco_visual_debugs(input[self.input_lq_index], debug_index) + self.produce_teco_visual_debugs(input[self.input_lq_index], input[self.self.recurrent_index], debug_index) debug_index += 1 gen_out = gen(*input) if isinstance(gen_out, torch.Tensor): @@ -127,13 +127,13 @@ class RecurrentImageGeneratorSequenceInjector(Injector): return {self.output: results} - def produce_teco_visual_debugs(self, gen_input, it): + def produce_teco_visual_debugs(self, gen_input, gen_recurrent, it): if self.env['rank'] > 0: return base_path = osp.join(self.env['base_path'], "..", "visual_dbg", "teco_geninput", str(self.env['step'])) os.makedirs(base_path, exist_ok=True) - torchvision.utils.save_image(gen_input[:, :3], osp.join(base_path, "%s_img.png" % (it,))) - torchvision.utils.save_image(gen_input[:, 3:], osp.join(base_path, "%s_recurrent.png" % (it,))) + torchvision.utils.save_image(gen_input, osp.join(base_path, "%s_img.png" % (it,))) + torchvision.utils.save_image(gen_recurrent, osp.join(base_path, "%s_recurrent.png" % (it,))) class FlowAdjustment(Injector):