diff --git a/codes/models/SRGAN_model.py b/codes/models/SRGAN_model.py index 902fa226..0bc2fce4 100644 --- a/codes/models/SRGAN_model.py +++ b/codes/models/SRGAN_model.py @@ -504,10 +504,10 @@ class SRGANModel(BaseModel): l_g_gan_grad_branch = self.l_gan_grad_w * self.cri_grad_gan(pred_g_fake_grad_branch, True) elif self.opt['train']['gan_type'] == 'ragan': pred_g_real_grad = self.netD_grad(self.get_grad_nopadding(var_ref)).detach() - l_g_gan_grad = self.l_gan_w * ( + l_g_gan_grad = self.l_gan_grad_w * ( self.cri_gan(pred_g_real_grad - torch.mean(pred_g_fake_grad), False) + self.cri_gan(pred_g_fake_grad - torch.mean(pred_g_real_grad), True)) / 2 - l_g_gan_grad_branch = self.l_gan_w * ( + l_g_gan_grad_branch = self.l_gan_grad_w * ( self.cri_gan(pred_g_real_grad - torch.mean(pred_g_fake_grad_branch), False) + self.cri_gan(pred_g_fake_grad_branch - torch.mean(pred_g_real_grad), True)) / 2 l_g_total += l_g_gan_grad + l_g_gan_grad_branch @@ -776,8 +776,8 @@ class SRGANModel(BaseModel): if self.cri_pix_branch: self.add_log_entry('l_g_pix_grad_branch', l_g_pix_grad_branch.detach().item()) if self.cri_grad_gan: - self.add_log_entry('l_g_gan_grad', l_g_gan_grad.detach().item()) - self.add_log_entry('l_g_gan_grad_branch', l_g_gan_grad_branch.detach().item()) + self.add_log_entry('l_g_gan_grad', l_g_gan_grad.detach().item() / self.l_gan_grad_w) + self.add_log_entry('l_g_gan_grad_branch', l_g_gan_grad_branch.detach().item() / self.l_gan_grad_w) if self.l_gan_w > 0 and step >= self.G_warmup: self.add_log_entry('l_d_real', l_d_real_log.detach().item()) self.add_log_entry('l_d_fake', l_d_fake_log.detach().item())