forked from mrq/DL-Art-School
Report l_g_gan_grad before weight multiplication
This commit is contained in:
parent
9d77a4db2e
commit
a498d7b1b3
|
@ -504,10 +504,10 @@ class SRGANModel(BaseModel):
|
|||
l_g_gan_grad_branch = self.l_gan_grad_w * self.cri_grad_gan(pred_g_fake_grad_branch, True)
|
||||
elif self.opt['train']['gan_type'] == 'ragan':
|
||||
pred_g_real_grad = self.netD_grad(self.get_grad_nopadding(var_ref)).detach()
|
||||
l_g_gan_grad = self.l_gan_w * (
|
||||
l_g_gan_grad = self.l_gan_grad_w * (
|
||||
self.cri_gan(pred_g_real_grad - torch.mean(pred_g_fake_grad), False) +
|
||||
self.cri_gan(pred_g_fake_grad - torch.mean(pred_g_real_grad), True)) / 2
|
||||
l_g_gan_grad_branch = self.l_gan_w * (
|
||||
l_g_gan_grad_branch = self.l_gan_grad_w * (
|
||||
self.cri_gan(pred_g_real_grad - torch.mean(pred_g_fake_grad_branch), False) +
|
||||
self.cri_gan(pred_g_fake_grad_branch - torch.mean(pred_g_real_grad), True)) / 2
|
||||
l_g_total += l_g_gan_grad + l_g_gan_grad_branch
|
||||
|
@ -776,8 +776,8 @@ class SRGANModel(BaseModel):
|
|||
if self.cri_pix_branch:
|
||||
self.add_log_entry('l_g_pix_grad_branch', l_g_pix_grad_branch.detach().item())
|
||||
if self.cri_grad_gan:
|
||||
self.add_log_entry('l_g_gan_grad', l_g_gan_grad.detach().item())
|
||||
self.add_log_entry('l_g_gan_grad_branch', l_g_gan_grad_branch.detach().item())
|
||||
self.add_log_entry('l_g_gan_grad', l_g_gan_grad.detach().item() / self.l_gan_grad_w)
|
||||
self.add_log_entry('l_g_gan_grad_branch', l_g_gan_grad_branch.detach().item() / self.l_gan_grad_w)
|
||||
if self.l_gan_w > 0 and step >= self.G_warmup:
|
||||
self.add_log_entry('l_d_real', l_d_real_log.detach().item())
|
||||
self.add_log_entry('l_d_fake', l_d_fake_log.detach().item())
|
||||
|
|
Loading…
Reference in New Issue
Block a user