More RAGAN fixes

James Betker 2020-08-05 16:47:21 -06:00
parent 26a6a5d512
commit be272248af


@@ -489,7 +489,7 @@ class SRGANModel(BaseModel):
                     l_g_gan_grad = self.l_gan_grad_w * self.cri_grad_gan(pred_g_fake_grad, True)
                     l_g_gan_grad_branch = self.l_gan_grad_w * self.cri_grad_gan(pred_g_fake_grad_branch, True)
                 elif self.opt['train']['gan_type'] == 'ragan':
-                    pred_g_real_grad = self.netD(self.get_grad_nopadding(var_ref)).detach()
+                    pred_g_real_grad = self.netD_grad(self.get_grad_nopadding(var_ref)).detach()
                     l_g_gan_grad = self.l_gan_w * (
                         self.cri_gan(pred_g_real_grad - torch.mean(pred_g_fake_grad), False) +
                         self.cri_gan(pred_g_fake_grad - torch.mean(pred_g_real_grad), True)) / 2
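
For context: the relativistic average ('ragan') generator loss scores each prediction relative to the mean prediction on the opposing batch, which is what this hunk computes for the gradient branch. The fix is that the real reference must be scored by the gradient discriminator netD_grad, not the image discriminator netD, so both terms of the relativistic pair come from the same network. A minimal sketch of the pattern, assuming a BCE-with-logits criterion in place of the repo's configurable cri_gan, with hypothetical stand-in names d, real, fake:

import torch
import torch.nn.functional as F

def ragan_g_loss(d, real, fake):
    # Generator step: real predictions are detached so no gradient
    # reaches the discriminator through the real images.
    pred_real = d(real).detach()
    pred_fake = d(fake)
    # Relativistic average: shift each logit by the mean of the
    # opposing batch before applying the usual GAN criterion.
    loss_real = F.binary_cross_entropy_with_logits(
        pred_real - pred_fake.mean(), torch.zeros_like(pred_real))
    loss_fake = F.binary_cross_entropy_with_logits(
        pred_fake - pred_real.mean(), torch.ones_like(pred_fake))
    return (loss_real + loss_fake) / 2
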
@@ -631,7 +631,7 @@ class SRGANModel(BaseModel):
                     fake_disc_images.append(pdf.view(disc_output_shape))
                 elif self.opt['train']['gan_type'] == 'ragan':
-                    pred_d_fake = self.netD(fake_H).detach()
+                    pred_d_fake = self.netD(fake_H)
                     pred_d_real = self.netD(var_ref)
                     l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True)
                     l_d_real_log = l_d_real
@@ -655,10 +655,10 @@ class SRGANModel(BaseModel):
                 p.requires_grad = True
             self.optimizer_D_grad.zero_grad()
             for var_ref, fake_H in zip(var_ref_skips, self.fake_H):
-                fake_H_grad = self.get_grad_nopadding(fake_H)
+                fake_H_grad = self.get_grad_nopadding(fake_H).detach()
                 var_ref_grad = self.get_grad_nopadding(var_ref)
                 pred_d_real_grad = self.netD_grad(var_ref_grad)
-                pred_d_fake_grad = self.netD_grad(fake_H_grad.detach())  # detach to avoid BP to G
+                pred_d_fake_grad = self.netD_grad(fake_H_grad)  # detach to avoid BP to G
                 if self.opt['train']['gan_type'] == 'gan':
                     l_d_real_grad = self.cri_gan(pred_d_real_grad, True) / self.mega_batch_factor
                     l_d_fake_grad = self.cri_gan(pred_d_fake_grad, False) / self.mega_batch_factor
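
The two detach moves above (and the one in the hunk at line 631) are about where the generator's graph is cut during a discriminator step: detach the generator's output before it enters the discriminator, never the discriminator's prediction, or the fake term stops producing gradient for D. A runnable sketch with hypothetical stand-ins netG/netD:

import torch
import torch.nn as nn

# Hypothetical stand-ins for the repo's generator and discriminator.
netG = nn.Conv2d(3, 3, 3, padding=1)
netD = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1),
                     nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, 1))

lr = torch.randn(4, 3, 32, 32)
hr = torch.randn(4, 3, 32, 32)

fake = netG(lr).detach()  # cut here: no backprop into G
pred_fake = netD(fake)    # stays attached: D still receives gradient
pred_real = netD(hr)
# Wrong: netD(fake).detach() would also zero D's gradient on the fake
# term, which is what the old 'ragan' branch at line 631 did.
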
@@ -668,10 +668,8 @@ class SRGANModel(BaseModel):
                     l_d_real_grad = self.cri_grad_gan(pred_d_real_grad, real)
                     l_d_fake_grad = self.cri_grad_gan(pred_d_fake_grad, fake)
                 elif self.opt['train']['gan_type'] == 'ragan':
-                    pred_g_fake_grad = self.netD_grad(fake_H_grad)
-                    pred_d_real_grad = self.netD_grad(var_ref_grad).detach()
-                    l_d_real_grad = self.cri_grad_gan(pred_d_real_grad - torch.mean(pred_g_fake_grad), True)
-                    l_d_fake_grad = self.cri_grad_gan(pred_g_fake_grad - torch.mean(pred_d_real_grad), False)
+                    l_d_real_grad = self.cri_grad_gan(pred_d_real_grad - torch.mean(pred_d_fake_grad), True)
+                    l_d_fake_grad = self.cri_grad_gan(pred_d_fake_grad - torch.mean(pred_d_real_grad), False)
                 l_d_total_grad = (l_d_real_grad + l_d_fake_grad) / 2
                 l_d_total_grad /= self.mega_batch_factor
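
With this rewrite, the 'ragan' branch reuses pred_d_real_grad and pred_d_fake_grad from the forward passes earlier in the loop rather than running netD_grad two more times, and it no longer detaches the real predictions (the old .detach() silenced the real term's gradient during the discriminator update). A sketch of the resulting loss, again assuming BCE-with-logits in place of the repo's configurable cri_grad_gan:

import torch
import torch.nn.functional as F

def ragan_d_loss(pred_real, pred_fake):
    # Discriminator step: both predictions stay attached to D's graph;
    # the fake *input* was detached upstream to keep G out of the step.
    loss_real = F.binary_cross_entropy_with_logits(
        pred_real - pred_fake.mean(), torch.ones_like(pred_real))
    loss_fake = F.binary_cross_entropy_with_logits(
        pred_fake - pred_real.mean(), torch.zeros_like(pred_fake))
    return (loss_real + loss_fake) / 2
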
@@ -743,8 +741,8 @@ class SRGANModel(BaseModel):
             if self.spsr_enabled:
                 self.add_log_entry('l_d_real_grad', l_d_real_grad.detach().item())
                 self.add_log_entry('l_d_fake_grad', l_d_fake_grad.detach().item())
-                self.add_log_entry('D_fake', torch.mean(pred_d_fake_grad.detach()))
-                self.add_log_entry('D_diff', torch.mean(pred_d_fake_grad.detach()) - torch.mean(pred_d_real_grad.detach()))
+                self.add_log_entry('D_fake_grad', torch.mean(pred_d_fake_grad.detach()))
+                self.add_log_entry('D_diff_grad', torch.mean(pred_d_fake_grad.detach()) - torch.mean(pred_d_real_grad.detach()))
             # Log learning rates.
             for i, pg in enumerate(self.optimizer_G.param_groups):