diff --git a/codes/models/SRGAN_model.py b/codes/models/SRGAN_model.py index 60e22bb4..3dc0f0e2 100644 --- a/codes/models/SRGAN_model.py +++ b/codes/models/SRGAN_model.py @@ -489,8 +489,7 @@ class SRGANModel(BaseModel): pred_g_fake_grad_branch = self.netD_grad(fake_H_branch) if self.opt['train']['gan_type'] == 'gan' or 'pixgan' in self.opt['train']['gan_type']: l_g_gan_grad = self.l_gan_grad_w * self.cri_grad_gan(pred_g_fake_grad, True) - # Uncomment to compute a discriminator loss against the grad branch. - #l_g_gan_grad_branch = self.l_gan_grad_w * self.cri_grad_gan(pred_g_fake_grad_branch, True) + l_g_gan_grad_branch = self.l_gan_grad_w * self.cri_grad_gan(pred_g_fake_grad_branch, True) elif self.opt['train']['gan_type'] == 'ragan': pred_g_real_grad = self.netD_grad(self.get_grad_nopadding(var_ref)).detach() l_g_gan_grad = self.l_gan_w * (