diff --git a/codes/models/SRGAN_model.py b/codes/models/SRGAN_model.py
index f59d44bf..60e22bb4 100644
--- a/codes/models/SRGAN_model.py
+++ b/codes/models/SRGAN_model.py
@@ -159,6 +159,8 @@ class SRGANModel(BaseModel):
             # GD gan loss
             self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device)
             self.l_gan_w = train_opt['gan_weight']
+            if train_opt['gan_type'] == 'pixgan':
+                self.do_pixgan_swap = True if 'do_pixgan_swap' not in train_opt.keys() else train_opt['do_pixgan_swap']
             # D_update_ratio and D_init_iters
             self.D_update_ratio = train_opt['D_update_ratio'] if train_opt['D_update_ratio'] else 1
             self.D_init_iters = train_opt['D_init_iters'] if train_opt['D_init_iters'] else 0
@@ -431,12 +433,12 @@ class SRGANModel(BaseModel):
                     l_g_pix_grad = self.l_pix_grad_w * self.cri_pix_grad(fake_H_grad, var_H_grad_nopadding)
                     l_g_total += l_g_pix_grad
                 if self.spsr_enabled and self.cri_pix_branch:  # branch pixel loss
-                    # The point of this loss is that the core structure of the grad image does not get mutated. Therefore,
-                    # downsample and compare against the input. The GAN loss will take care of the details in HR-space.
-                    var_L_grad = self.get_grad_nopadding(var_L)
-                    downsampled_H_branch = F.interpolate(fake_H_branch, size=var_L_grad.shape[2:], mode="nearest")
+                    grad_truth = self.get_grad_nopadding(var_L)
+                    downsampled_H_branch = fake_H_branch
+                    if grad_truth.shape != fake_H_branch.shape:
+                        downsampled_H_branch = F.interpolate(downsampled_H_branch, size=grad_truth.shape[2:], mode="nearest")
                     l_g_pix_grad_branch = self.l_pix_branch_w * self.cri_pix_branch(downsampled_H_branch,
-                                                                                    var_L_grad)
+                                                                                    grad_truth)
                     l_g_total += l_g_pix_grad_branch
                 if self.fdpl_enabled and not using_gan_img:
                     l_g_fdpl = self.cri_fdpl(fea_GenOut, pix)
@@ -487,7 +489,8 @@ class SRGANModel(BaseModel):
                     pred_g_fake_grad_branch = self.netD_grad(fake_H_branch)
                     if self.opt['train']['gan_type'] == 'gan' or 'pixgan' in self.opt['train']['gan_type']:
                         l_g_gan_grad = self.l_gan_grad_w * self.cri_grad_gan(pred_g_fake_grad, True)
-                        l_g_gan_grad_branch = self.l_gan_grad_w * self.cri_grad_gan(pred_g_fake_grad_branch, True)
+                        # Uncomment to compute a discriminator loss against the grad branch.
+                        #l_g_gan_grad_branch = self.l_gan_grad_w * self.cri_grad_gan(pred_g_fake_grad_branch, True)
                     elif self.opt['train']['gan_type'] == 'ragan':
                         pred_g_real_grad = self.netD_grad(self.get_grad_nopadding(var_ref)).detach()
                         l_g_gan_grad = self.l_gan_w * (
@@ -576,7 +579,7 @@ class SRGANModel(BaseModel):
                 b, _, w, h = var_ref.shape
                 real = torch.ones((b, pixdisc_channels, w, h), device=var_ref.device)
                 fake = torch.zeros((b, pixdisc_channels, w, h), device=var_ref.device)
-                if not self.disjoint_data:
+                if self.do_pixgan_swap and not self.disjoint_data:
                     # randomly determine portions of the image to swap to keep the discriminator honest.
                     SWAP_MAX_DIM = w // 4
                     SWAP_MIN_DIM = 16
@@ -654,22 +657,27 @@ class SRGANModel(BaseModel):
             for p in self.netD_grad.parameters():
                 p.requires_grad = True
             self.optimizer_D_grad.zero_grad()
-            for var_ref, fake_H in zip(var_ref_skips, self.fake_H):
+            for var_ref, fake_H, fake_H_grad_branch in zip(var_ref_skips, self.fake_H, self.spsr_grad_GenOut):
                 fake_H_grad = self.get_grad_nopadding(fake_H).detach()
                 var_ref_grad = self.get_grad_nopadding(var_ref)
                 pred_d_real_grad = self.netD_grad(var_ref_grad)
-                pred_d_fake_grad = self.netD_grad(fake_H_grad)  # detach to avoid BP to G
+                pred_d_fake_grad = self.netD_grad(fake_H_grad)  # Tensor already detached above.
+                # var_ref and fake_H already have noise added to them. We must add noise to fake_H_grad_branch too.
+                fake_H_grad_branch = fake_H_grad_branch.detach() + noise
+                pred_d_fake_grad_branch = self.netD_grad(fake_H_grad_branch)
                 if self.opt['train']['gan_type'] == 'gan':
-                    l_d_real_grad = self.cri_gan(pred_d_real_grad, True) / self.mega_batch_factor
-                    l_d_fake_grad = self.cri_gan(pred_d_fake_grad, False) / self.mega_batch_factor
+                    l_d_real_grad = self.cri_gan(pred_d_real_grad, True)
+                    l_d_fake_grad = (self.cri_gan(pred_d_fake_grad, False) + self.cri_gan(pred_d_fake_grad_branch, False)) / 2
                 elif self.opt['train']['gan_type'] == 'pixgan':
                     real = torch.ones_like(pred_d_real_grad)
                     fake = torch.zeros_like(pred_d_fake_grad)
                     l_d_real_grad = self.cri_grad_gan(pred_d_real_grad, real)
-                    l_d_fake_grad = self.cri_grad_gan(pred_d_fake_grad, fake)
+                    l_d_fake_grad = (self.cri_grad_gan(pred_d_fake_grad, fake) + \
+                                     self.cri_grad_gan(pred_d_fake_grad_branch, fake)) / 2
                 elif self.opt['train']['gan_type'] == 'ragan':
                     l_d_real_grad = self.cri_grad_gan(pred_d_real_grad - torch.mean(pred_d_fake_grad), True)
-                    l_d_fake_grad = self.cri_grad_gan(pred_d_fake_grad - torch.mean(pred_d_real_grad), False)
+                    l_d_fake_grad = (self.cri_grad_gan(pred_d_fake_grad - torch.mean(pred_d_real_grad), False) + \
+                                     self.cri_grad_gan(pred_d_fake_grad_branch - torch.mean(pred_d_real_grad), False)) / 2
                 l_d_total_grad = (l_d_real_grad + l_d_fake_grad) / 2
                 l_d_total_grad /= self.mega_batch_factor
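
Note (not part of the patch): the hunk at line 159 makes the pixgan swap behaviour configurable. A minimal sketch of how the new 'do_pixgan_swap' key would be resolved from a train_opt dictionary, assuming a hypothetical options dict; the swap stays enabled when the key is absent.

# Hypothetical training options; only 'gan_type' matters for this branch.
train_opt = {'gan_type': 'pixgan', 'gan_weight': 1.0}

# Mirrors the added line: default to True when 'do_pixgan_swap' is not configured.
do_pixgan_swap = True if 'do_pixgan_swap' not in train_opt.keys() else train_opt['do_pixgan_swap']
assert do_pixgan_swap is True

# Explicitly setting the option to False disables the swap behaviour.
train_opt['do_pixgan_swap'] = False
do_pixgan_swap = True if 'do_pixgan_swap' not in train_opt.keys() else train_opt['do_pixgan_swap']
assert do_pixgan_swap is False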
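
Note (not part of the patch): a minimal sketch of the shape-guarded branch downsample introduced in the hunk at line 433. The tensors and shapes below are hypothetical stand-ins for get_grad_nopadding(var_L) and the grad-branch generator output.

import torch
import torch.nn.functional as F

# Stand-ins: grad_truth plays the role of get_grad_nopadding(var_L),
# fake_H_branch the role of the SPSR grad-branch output at HR resolution.
grad_truth = torch.randn(1, 3, 32, 32)
fake_H_branch = torch.randn(1, 3, 128, 128)

# Only resample when the branch output is not already at the reference
# resolution; otherwise it is compared against grad_truth as-is.
downsampled_H_branch = fake_H_branch
if grad_truth.shape != fake_H_branch.shape:
    downsampled_H_branch = F.interpolate(downsampled_H_branch, size=grad_truth.shape[2:], mode="nearest")

assert downsampled_H_branch.shape == grad_truth.shape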