From be272248af398b976a9af4e0772b999e5c7ceb73 Mon Sep 17 00:00:00 2001 From: James Betker Date: Wed, 5 Aug 2020 16:47:21 -0600 Subject: [PATCH] More RAGAN fixes --- codes/models/SRGAN_model.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/codes/models/SRGAN_model.py b/codes/models/SRGAN_model.py index 6af272bd..f59d44bf 100644 --- a/codes/models/SRGAN_model.py +++ b/codes/models/SRGAN_model.py @@ -489,7 +489,7 @@ class SRGANModel(BaseModel): l_g_gan_grad = self.l_gan_grad_w * self.cri_grad_gan(pred_g_fake_grad, True) l_g_gan_grad_branch = self.l_gan_grad_w * self.cri_grad_gan(pred_g_fake_grad_branch, True) elif self.opt['train']['gan_type'] == 'ragan': - pred_g_real_grad = self.netD(self.get_grad_nopadding(var_ref)).detach() + pred_g_real_grad = self.netD_grad(self.get_grad_nopadding(var_ref)).detach() l_g_gan_grad = self.l_gan_w * ( self.cri_gan(pred_g_real_grad - torch.mean(pred_g_fake_grad), False) + self.cri_gan(pred_g_fake_grad - torch.mean(pred_g_real_grad), True)) / 2 @@ -631,7 +631,7 @@ class SRGANModel(BaseModel): fake_disc_images.append(pdf.view(disc_output_shape)) elif self.opt['train']['gan_type'] == 'ragan': - pred_d_fake = self.netD(fake_H).detach() + pred_d_fake = self.netD(fake_H) pred_d_real = self.netD(var_ref) l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True) l_d_real_log = l_d_real @@ -655,10 +655,10 @@ class SRGANModel(BaseModel): p.requires_grad = True self.optimizer_D_grad.zero_grad() for var_ref, fake_H in zip(var_ref_skips, self.fake_H): - fake_H_grad = self.get_grad_nopadding(fake_H) + fake_H_grad = self.get_grad_nopadding(fake_H).detach() var_ref_grad = self.get_grad_nopadding(var_ref) pred_d_real_grad = self.netD_grad(var_ref_grad) - pred_d_fake_grad = self.netD_grad(fake_H_grad.detach()) # detach to avoid BP to G + pred_d_fake_grad = self.netD_grad(fake_H_grad) # detach to avoid BP to G if self.opt['train']['gan_type'] == 'gan': l_d_real_grad = self.cri_gan(pred_d_real_grad, True) / self.mega_batch_factor l_d_fake_grad = self.cri_gan(pred_d_fake_grad, False) / self.mega_batch_factor @@ -668,10 +668,8 @@ class SRGANModel(BaseModel): l_d_real_grad = self.cri_grad_gan(pred_d_real_grad, real) l_d_fake_grad = self.cri_grad_gan(pred_d_fake_grad, fake) elif self.opt['train']['gan_type'] == 'ragan': - pred_g_fake_grad = self.netD_grad(fake_H_grad) - pred_d_real_grad = self.netD_grad(var_ref_grad).detach() - l_d_real_grad = self.cri_grad_gan(pred_d_real_grad - torch.mean(pred_g_fake_grad), True) - l_d_fake_grad = self.cri_grad_gan(pred_g_fake_grad - torch.mean(pred_d_real_grad), False) + l_d_real_grad = self.cri_grad_gan(pred_d_real_grad - torch.mean(pred_d_fake_grad), True) + l_d_fake_grad = self.cri_grad_gan(pred_d_fake_grad - torch.mean(pred_d_real_grad), False) l_d_total_grad = (l_d_real_grad + l_d_fake_grad) / 2 l_d_total_grad /= self.mega_batch_factor @@ -743,8 +741,8 @@ class SRGANModel(BaseModel): if self.spsr_enabled: self.add_log_entry('l_d_real_grad', l_d_real_grad.detach().item()) self.add_log_entry('l_d_fake_grad', l_d_fake_grad.detach().item()) - self.add_log_entry('D_fake', torch.mean(pred_d_fake_grad.detach())) - self.add_log_entry('D_diff', torch.mean(pred_d_fake_grad.detach()) - torch.mean(pred_d_real_grad.detach())) + self.add_log_entry('D_fake_grad', torch.mean(pred_d_fake_grad.detach())) + self.add_log_entry('D_diff_grad', torch.mean(pred_d_fake_grad.detach()) - torch.mean(pred_d_real_grad.detach())) # Log learning rates. for i, pg in enumerate(self.optimizer_G.param_groups):