diff --git a/codes/models/SRGAN_model.py b/codes/models/SRGAN_model.py index 17e0b59b..5d4f4274 100644 --- a/codes/models/SRGAN_model.py +++ b/codes/models/SRGAN_model.py @@ -395,7 +395,6 @@ class SRGANModel(BaseModel): using_gan_img = False # Get image gradients for later use. fake_H_grad = self.get_grad_nopadding(fake_GenOut) - var_ref_grad = self.get_grad_nopadding(var_ref) var_H_grad_nopadding = self.get_grad_nopadding(var_H) self.spsr_grad_GenOut.append(grad_LR) else: @@ -477,10 +476,7 @@ class SRGANModel(BaseModel): if self.spsr_enabled and self.cri_grad_gan: # grad G gan + cls loss pred_g_fake_grad = self.netD_grad(fake_H_grad) - pred_d_real_grad = self.netD_grad(var_ref_grad).detach() - - l_g_gan_grad = self.l_gan_grad_w * (self.cri_grad_gan(pred_d_real_grad - torch.mean(pred_g_fake_grad), False) + - self.cri_grad_gan(pred_g_fake_grad - torch.mean(pred_d_real_grad), True)) /2 + l_g_gan_grad = self.l_gan_grad_w * self.cri_grad_gan(pred_g_fake_grad, True) l_g_total += l_g_gan_grad # Scale the loss down by the batch factor. @@ -508,7 +504,6 @@ class SRGANModel(BaseModel): noise = torch.randn_like(var_ref) * noise_theta noise.to(self.device) - self.optimizer_D.zero_grad() real_disc_images = [] fake_disc_images = [] for var_L, var_LGAN, var_H, var_ref, pix in zip(self.var_L, self.gan_img, self.var_H, self.var_ref, self.pix): @@ -533,6 +528,7 @@ class SRGANModel(BaseModel): fake_H = fake_H + noise l_d_fea_real = 0 l_d_fea_fake = 0 + self.optimizer_D.zero_grad() if self.opt['train']['gan_type'] == 'pixgan_fea': # Compute a feature loss which is added to the GAN loss computed later to guide the discriminator better. disc_fea_scale = .1 @@ -548,14 +544,14 @@ class SRGANModel(BaseModel): pred_d_real = self.netD(var_ref) l_d_real = self.cri_gan(pred_d_real, True) / self.mega_batch_factor l_d_real_log = l_d_real * self.mega_batch_factor - with amp.scale_loss(l_d_real, self.optimizer_D, loss_id=2) as l_d_real_scaled: - l_d_real_scaled.backward() # fake pred_d_fake = self.netD(fake_H) l_d_fake = self.cri_gan(pred_d_fake, False) / self.mega_batch_factor l_d_fake_log = l_d_fake * self.mega_batch_factor - with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled: - l_d_fake_scaled.backward() + + l_d_total = (l_d_real + l_d_fake) / 2 + with amp.scale_loss(l_d_total, self.optimizer_D, loss_id=1) as l_d_total_scaled: + l_d_total_scaled.backward() if 'pixgan' in self.opt['train']['gan_type']: pixdisc_channels, pixdisc_output_reduction = self.netD.module.pixgan_parameters() disc_output_shape = (var_ref.shape[0], pixdisc_channels, var_ref.shape[2] // pixdisc_output_reduction, var_ref.shape[3] // pixdisc_output_reduction) @@ -599,15 +595,15 @@ class SRGANModel(BaseModel): l_d_real = self.cri_gan(pred_d_real, real) / self.mega_batch_factor l_d_real_log = l_d_real * self.mega_batch_factor l_d_real += l_d_fea_real - with amp.scale_loss(l_d_real, self.optimizer_D, loss_id=2) as l_d_real_scaled: - l_d_real_scaled.backward() # fake pred_d_fake = self.netD(fake_H) l_d_fake = self.cri_gan(pred_d_fake, fake) / self.mega_batch_factor l_d_fake_log = l_d_fake * self.mega_batch_factor l_d_fake += l_d_fea_fake - with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled: - l_d_fake_scaled.backward() + + l_d_total = (l_d_real + l_d_fake) / 2 + with amp.scale_loss(l_d_total, self.optimizer_D, loss_id=1) as l_d_total_scaled: + l_d_total_scaled.backward() pdr = pred_d_real.detach() + torch.abs(torch.min(pred_d_real)) pdr = pdr / torch.max(pdr) @@ -643,7 +639,6 @@ class SRGANModel(BaseModel): print("Disc forward/backward 2 (RAGAN) %f" % (time() - _t,)) _t = time() - # Append var_ref here, so that we can inspect the alterations the disc made if pixgan var_ref_skips.append(var_ref.detach()) self.fake_H.append(fake_H.detach()) self.optimizer_D.step() @@ -657,20 +652,19 @@ class SRGANModel(BaseModel): for p in self.netD_grad.parameters(): p.requires_grad = True self.optimizer_D_grad.zero_grad() - - for var_ref, fake_H in zip(self.var_ref, self.fake_H): + for var_ref, fake_H in zip(self.var_ref_skips, self.fake_H): fake_H_grad = self.get_grad_nopadding(fake_H) var_ref_grad = self.get_grad_nopadding(var_ref) pred_d_real_grad = self.netD_grad(var_ref_grad) pred_d_fake_grad = self.netD_grad(fake_H_grad.detach()) # detach to avoid BP to G if self.opt['train']['gan_type'] == 'gan': - l_d_real_grad = self.cri_grad_gan(pred_d_real_grad - torch.mean(pred_d_fake_grad), True) - l_d_fake_grad = self.cri_grad_gan(pred_d_fake_grad - torch.mean(pred_d_real_grad), False) + l_d_real_grad = self.cri_gan(pred_d_real_grad, True) / self.mega_batch_factor + l_d_fake_grad = self.cri_gan(pred_d_fake_grad, False) / self.mega_batch_factor elif self.opt['train']['gan_type'] == 'pixgan': real = torch.ones_like(pred_d_real_grad) fake = torch.zeros_like(pred_d_fake_grad) - l_d_real_grad = self.cri_grad_gan(pred_d_real_grad - torch.mean(pred_d_fake_grad), real) - l_d_fake_grad = self.cri_grad_gan(pred_d_fake_grad - torch.mean(pred_d_real_grad), fake) + l_d_real_grad = self.cri_grad_gan(pred_d_real_grad, real) + l_d_fake_grad = self.cri_grad_gan(pred_d_fake_grad, fake) l_d_total_grad = (l_d_real_grad + l_d_fake_grad) / 2 l_d_total_grad /= self.mega_batch_factor with amp.scale_loss(l_d_total_grad, self.optimizer_D_grad, loss_id=2) as l_d_total_grad_scaled: @@ -932,4 +926,5 @@ class SRGANModel(BaseModel): def save(self, iter_step): self.save_network(self.netG, 'G', iter_step) self.save_network(self.netD, 'D', iter_step) - self.save_network(self.netD_grad, 'D_grad', iter_step) + if self.spsr_enabled: + self.save_network(self.netD_grad, 'D_grad', iter_step)