Report l_g_gan_grad before weight multiplication

This commit is contained in:
James Betker 2020-08-20 11:57:53 -06:00
parent 9d77a4db2e
commit a498d7b1b3

View File

@ -504,10 +504,10 @@ class SRGANModel(BaseModel):
l_g_gan_grad_branch = self.l_gan_grad_w * self.cri_grad_gan(pred_g_fake_grad_branch, True) l_g_gan_grad_branch = self.l_gan_grad_w * self.cri_grad_gan(pred_g_fake_grad_branch, True)
elif self.opt['train']['gan_type'] == 'ragan': elif self.opt['train']['gan_type'] == 'ragan':
pred_g_real_grad = self.netD_grad(self.get_grad_nopadding(var_ref)).detach() pred_g_real_grad = self.netD_grad(self.get_grad_nopadding(var_ref)).detach()
l_g_gan_grad = self.l_gan_w * ( l_g_gan_grad = self.l_gan_grad_w * (
self.cri_gan(pred_g_real_grad - torch.mean(pred_g_fake_grad), False) + self.cri_gan(pred_g_real_grad - torch.mean(pred_g_fake_grad), False) +
self.cri_gan(pred_g_fake_grad - torch.mean(pred_g_real_grad), True)) / 2 self.cri_gan(pred_g_fake_grad - torch.mean(pred_g_real_grad), True)) / 2
l_g_gan_grad_branch = self.l_gan_w * ( l_g_gan_grad_branch = self.l_gan_grad_w * (
self.cri_gan(pred_g_real_grad - torch.mean(pred_g_fake_grad_branch), False) + self.cri_gan(pred_g_real_grad - torch.mean(pred_g_fake_grad_branch), False) +
self.cri_gan(pred_g_fake_grad_branch - torch.mean(pred_g_real_grad), True)) / 2 self.cri_gan(pred_g_fake_grad_branch - torch.mean(pred_g_real_grad), True)) / 2
l_g_total += l_g_gan_grad + l_g_gan_grad_branch l_g_total += l_g_gan_grad + l_g_gan_grad_branch
@ -776,8 +776,8 @@ class SRGANModel(BaseModel):
if self.cri_pix_branch: if self.cri_pix_branch:
self.add_log_entry('l_g_pix_grad_branch', l_g_pix_grad_branch.detach().item()) self.add_log_entry('l_g_pix_grad_branch', l_g_pix_grad_branch.detach().item())
if self.cri_grad_gan: if self.cri_grad_gan:
self.add_log_entry('l_g_gan_grad', l_g_gan_grad.detach().item()) self.add_log_entry('l_g_gan_grad', l_g_gan_grad.detach().item() / self.l_gan_grad_w)
self.add_log_entry('l_g_gan_grad_branch', l_g_gan_grad_branch.detach().item()) self.add_log_entry('l_g_gan_grad_branch', l_g_gan_grad_branch.detach().item() / self.l_gan_grad_w)
if self.l_gan_w > 0 and step >= self.G_warmup: if self.l_gan_w > 0 and step >= self.G_warmup:
self.add_log_entry('l_d_real', l_d_real_log.detach().item()) self.add_log_entry('l_d_real', l_d_real_log.detach().item())
self.add_log_entry('l_d_fake', l_d_fake_log.detach().item()) self.add_log_entry('l_d_fake', l_d_fake_log.detach().item())