diff --git a/codes/models/archs/StructuredSwitchedGenerator.py b/codes/models/archs/StructuredSwitchedGenerator.py index 59d1b893..d8e7d638 100644 --- a/codes/models/archs/StructuredSwitchedGenerator.py +++ b/codes/models/archs/StructuredSwitchedGenerator.py @@ -484,7 +484,7 @@ class StackedSwitchGenerator5Layer(nn.Module): prefix = "amap_%i_a%i_%%i.png" [save_attention_to_image_rgb(output_path, self.attentions[i], self.nf, prefix % (step, i), step, output_mag=False) for i in range(len(self.attentions))] - torchvision.utils.save_image(self.lr, os.path.join(experiments_path, "attention_maps", + torchvision.utils.save_image(self.lr[:,:3], os.path.join(experiments_path, "attention_maps", "amap_%i_base_image.png" % (step,))) def get_debug_values(self, step, net_name): diff --git a/codes/models/steps/tecogan_losses.py b/codes/models/steps/tecogan_losses.py index a9ca6d1f..e6c6fa92 100644 --- a/codes/models/steps/tecogan_losses.py +++ b/codes/models/steps/tecogan_losses.py @@ -22,21 +22,23 @@ def create_teco_injector(opt, env): return RecurrentImageGeneratorSequenceInjector(opt, env) return None -def create_teco_discriminator_sextuplet(input_list, lr_imgs, scale, index, flow_gen, resampler): +def create_teco_discriminator_sextuplet(input_list, lr_imgs, scale, index, flow_gen, resampler, margin): triplet = input_list[:, index:index+3] # Flow is interpreted from the LR images so that the generator cannot learn to manipulate it. with torch.no_grad(): - first_flow = flow_gen(torch.stack([lr_imgs[:,0], lr_imgs[:,1]], dim=2)) + first_flow = flow_gen(torch.stack([lr_imgs[:,1], lr_imgs[:,0]], dim=2)) first_flow = F.interpolate(first_flow, scale_factor=scale, mode='bicubic') - last_flow = flow_gen(torch.stack([lr_imgs[:,2], lr_imgs[:,1]], dim=2)) + last_flow = flow_gen(torch.stack([lr_imgs[:,1], lr_imgs[:,2]], dim=2)) last_flow = F.interpolate(last_flow, scale_factor=scale, mode='bicubic') flow_triplet = [resampler(triplet[:,0].float(), first_flow.float()), triplet[:,1], resampler(triplet[:,2].float(), last_flow.float())] - flow_triplet = torch.stack(flow_triplet, dim=2) - combined = torch.cat([triplet, flow_triplet], dim=2) + flow_triplet = torch.stack(flow_triplet, dim=1) + combined = torch.cat([triplet, flow_triplet], dim=1) b, f, c, h, w = combined.shape - return combined.view(b, 3*6, h, w) # 3*6 is essentially an assertion here. + combined = combined.view(b, 3*6, h, w) # 3*6 is essentially an assertion here. + # Apply margin + return combined[:, :, margin:-margin, margin:-margin] def extract_inputs_index(inputs, i): @@ -72,6 +74,7 @@ class RecurrentImageGeneratorSequenceInjector(Injector): # Go forward in the sequence first. first_step = True b, f, c, h, w = inputs[self.input_lq_index].shape + debug_index = 0 for i in range(f): input = extract_inputs_index(inputs, i) if first_step: @@ -84,6 +87,9 @@ class RecurrentImageGeneratorSequenceInjector(Injector): # Resample does not work in FP16. recurrent_input = self.resample(reduced_recurrent.float(), flowfield.float()) input[self.input_lq_index] = torch.cat([input[self.input_lq_index], recurrent_input], dim=1) + if self.env['step'] % 20 == 0: + self.produce_teco_visual_debugs(input[self.input_lq_index], debug_index) + debug_index += 1 gen_out = gen(*input) recurrent_input = gen_out[self.output_hq_index] results.append(recurrent_input) @@ -98,12 +104,21 @@ class RecurrentImageGeneratorSequenceInjector(Injector): flowfield = flow(flow_input) recurrent_input = self.resample(reduced_recurrent.float(), flowfield.float()) input[self.input_lq_index] = torch.cat([input[self.input_lq_index], recurrent_input], dim=1) + if self.env['step'] % 20 == 0: + self.produce_teco_visual_debugs(input[self.input_lq_index], debug_index) + debug_index += 1 gen_out = gen(*input) recurrent_input = gen_out[self.output_hq_index] results.append(recurrent_input) return {self.output: results} + def produce_teco_visual_debugs(self, gen_input, it): + base_path = osp.join(self.env['base_path'], "..", "visual_dbg", "teco_geninput", str(self.env['step'])) + os.makedirs(base_path, exist_ok=True) + torchvision.utils.save_image(gen_input[:, :3], osp.join(base_path, "%s_img.png" % (it,))) + torchvision.utils.save_image(gen_input[:, 3:], osp.join(base_path, "%s_recurrent.png" % (it,))) + # This is the temporal discriminator loss from TecoGAN. # @@ -128,6 +143,7 @@ class TecoGanLoss(ConfigurableLoss): self.image_flow_generator = opt['image_flow_generator'] self.resampler = Resample2d() self.for_generator = opt['for_generator'] + self.margin = opt['margin'] # Per the tecogan paper, the GAN loss only pays attention to an inner part of the image with the margin removed, to get rid of artifacts resulting from flow errors. def forward(self, _, state): net = self.env['discriminators'][self.opt['discriminator']] @@ -138,16 +154,18 @@ class TecoGanLoss(ConfigurableLoss): lr = state[self.opt['lr_inputs']] l_total = 0 for i in range(sequence_len - 2): - real_sext = create_teco_discriminator_sextuplet(real, lr, self.scale, i, flow_gen, self.resampler) - fake_sext = create_teco_discriminator_sextuplet(fake, lr, self.scale, i, flow_gen, self.resampler) + real_sext = create_teco_discriminator_sextuplet(real, lr, self.scale, i, flow_gen, self.resampler, self.margin) + fake_sext = create_teco_discriminator_sextuplet(fake, lr, self.scale, i, flow_gen, self.resampler, self.margin) d_fake = net(fake_sext) + d_real = net(real_sext) + self.metrics.append(("d_fake", torch.mean(d_fake))) + self.metrics.append(("d_real", torch.mean(d_real))) - if self.for_generator and self.env['step'] % 100 == 0: + if self.for_generator and self.env['step'] % 20 == 0: self.produce_teco_visual_debugs(fake_sext, 'fake', i) self.produce_teco_visual_debugs(real_sext, 'real', i) if self.opt['gan_type'] in ['gan', 'pixgan']: - self.metrics.append(("d_fake", torch.mean(d_fake))) l_fake = self.criterion(d_fake, self.for_generator) if not self.for_generator: l_real = self.criterion(d_real, True) @@ -155,7 +173,6 @@ class TecoGanLoss(ConfigurableLoss): l_real = 0 l_total += l_fake + l_real elif self.opt['gan_type'] == 'ragan': - d_real = net(real_sext) d_fake_diff = d_fake - torch.mean(d_real) self.metrics.append(("d_fake_diff", torch.mean(d_fake_diff))) l_total += (self.criterion(d_real - torch.mean(d_fake), not self.for_generator) + @@ -166,11 +183,11 @@ class TecoGanLoss(ConfigurableLoss): return l_total def produce_teco_visual_debugs(self, sext, lbl, it): - base_path = osp.join(self.env['base_path'], "visual_dbg", "teco_sext", str(self.env['step']), lbl) + base_path = osp.join(self.env['base_path'], "..", "visual_dbg", "teco_sext", str(self.env['step']), lbl) os.makedirs(base_path, exist_ok=True) - lbls = ['first', 'second', 'third', 'first_flow', 'second_flow', 'third_flow'] + lbls = ['img_a', 'img_b', 'img_c', 'flow_a', 'flow_b', 'flow_c'] for i in range(6): - torchvision.utils.save_image(sext[:, i*3:(i+1)*3-1, :, :], osp.join(base_path, "%s_%s.png" % (lbls[i], it))) + torchvision.utils.save_image(sext[:, i*3:(i+1)*3, :, :], osp.join(base_path, "%s_%s.png" % (it, lbls[i]))) # This loss doesn't have a real entry - only fakes are used. @@ -188,13 +205,13 @@ class PingPongLoss(ConfigurableLoss): late = fake[-i] l_total += self.criterion(early, late) - if self.env['step'] % 100 == 0: + if self.env['step'] % 20 == 0: self.produce_teco_visual_debugs(fake) return l_total def produce_teco_visual_debugs(self, imglist): - base_path = osp.join(self.env['base_path'], "visual_dbg", "teco_pingpong", str(self.env['step'])) + base_path = osp.join(self.env['base_path'], "..", "visual_dbg", "teco_pingpong", str(self.env['step'])) os.makedirs(base_path, exist_ok=True) assert isinstance(imglist, list) for i, img in enumerate(imglist):