Comptue gan_grad_branch....

This commit is contained in:
James Betker 2020-08-06 12:11:40 -06:00
parent 30b16d5235
commit fd7b6ca0a9

View File

@ -489,8 +489,7 @@ class SRGANModel(BaseModel):
pred_g_fake_grad_branch = self.netD_grad(fake_H_branch)
if self.opt['train']['gan_type'] == 'gan' or 'pixgan' in self.opt['train']['gan_type']:
l_g_gan_grad = self.l_gan_grad_w * self.cri_grad_gan(pred_g_fake_grad, True)
# Uncomment to compute a discriminator loss against the grad branch.
#l_g_gan_grad_branch = self.l_gan_grad_w * self.cri_grad_gan(pred_g_fake_grad_branch, True)
l_g_gan_grad_branch = self.l_gan_grad_w * self.cri_grad_gan(pred_g_fake_grad_branch, True)
elif self.opt['train']['gan_type'] == 'ragan':
pred_g_real_grad = self.netD_grad(self.get_grad_nopadding(var_ref)).detach()
l_g_gan_grad = self.l_gan_w * (