diff --git a/codes/models/SRGAN_model.py b/codes/models/SRGAN_model.py
index 66fc05a6..19665e67 100644
--- a/codes/models/SRGAN_model.py
+++ b/codes/models/SRGAN_model.py
@@ -239,16 +239,17 @@ class SRGANModel(BaseModel):
                     if step % self.l_fea_w_decay_steps == 0:
                         self.l_fea_w = max(self.l_fea_w_minimum, self.l_fea_w * self.l_fea_w_decay)
 
-                if self.opt['train']['gan_type'] == 'gan':
-                    pred_g_fake = self.netD(fake_GenOut)
-                    l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True)
-                elif self.opt['train']['gan_type'] == 'ragan':
-                    pred_d_real = self.netD(var_ref).detach()
-                    pred_g_fake = self.netD(fake_GenOut)
-                    l_g_gan = self.l_gan_w * (
-                        self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False) +
-                        self.cri_gan(pred_g_fake - torch.mean(pred_d_real), True)) / 2
-                l_g_total += l_g_gan
+                if self.l_gan_w > 0:
+                    if self.opt['train']['gan_type'] == 'gan':
+                        pred_g_fake = self.netD(fake_GenOut)
+                        l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True)
+                    elif self.opt['train']['gan_type'] == 'ragan':
+                        pred_d_real = self.netD(var_ref).detach()
+                        pred_g_fake = self.netD(fake_GenOut)
+                        l_g_gan = self.l_gan_w * (
+                            self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False) +
+                            self.cri_gan(pred_g_fake - torch.mean(pred_d_real), True)) / 2
+                    l_g_total += l_g_gan
 
                 # Scale the loss down by the batch factor.
                 l_g_total = l_g_total / self.mega_batch_factor
@@ -258,51 +259,52 @@ class SRGANModel(BaseModel):
         self.optimizer_G.step()
 
         # D
-        for p in self.netD.parameters():
-            p.requires_grad = True
+        if self.l_gan_w > 0:
+            for p in self.netD.parameters():
+                p.requires_grad = True
 
-        noise = torch.randn_like(var_ref[0]) * noise_theta
-        noise.to(self.device)
-        self.optimizer_D.zero_grad()
-        for var_L, var_H, var_ref, pix in zip(self.var_L, self.var_H, var_ref_skips, self.pix):
-            # Re-compute generator outputs (post-update).
-            with torch.no_grad():
-                fake_H = self.netG(var_L)
-            # The following line detaches all generator outputs that are not None.
-            fake_H = tuple([(x.detach() if x is not None else None) for x in list(fake_H)])
+            noise = torch.randn_like(var_ref[0]) * noise_theta
+            noise = noise.to(self.device)  # Tensor.to() is not in-place; keep the returned tensor.
+            self.optimizer_D.zero_grad()
+            for var_L, var_H, var_ref, pix in zip(self.var_L, self.var_H, var_ref_skips, self.pix):
+                # Re-compute generator outputs (post-update).
+                with torch.no_grad():
+                    fake_H = self.netG(var_L)
+                # The following line detaches all generator outputs that are not None.
+                fake_H = tuple([(x.detach() if x is not None else None) for x in list(fake_H)])
 
-            # Apply noise to the inputs to slow discriminator convergence.
-            var_ref = (var_ref[0] + noise,) + var_ref[1:]
-            fake_H = (fake_H[0] + noise,) + fake_H[1:]
-            if self.opt['train']['gan_type'] == 'gan':
-                # need to forward and backward separately, since batch norm statistics differ
-                # real
-                pred_d_real = self.netD(var_ref)
-                l_d_real = self.cri_gan(pred_d_real, True) / self.mega_batch_factor
-                with amp.scale_loss(l_d_real, self.optimizer_D, loss_id=2) as l_d_real_scaled:
-                    l_d_real_scaled.backward()
-                # fake
-                pred_d_fake = self.netD(fake_H)
-                l_d_fake = self.cri_gan(pred_d_fake, False) / self.mega_batch_factor
-                with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled:
-                    l_d_fake_scaled.backward()
-            elif self.opt['train']['gan_type'] == 'ragan':
-                # pred_d_real = self.netD(var_ref)
-                # pred_d_fake = self.netD(fake_H.detach())  # detach to avoid BP to G
-                # l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True)
-                # l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real), False)
-                # l_d_total = (l_d_real + l_d_fake) / 2
-                # l_d_total.backward()
-                pred_d_fake = self.netD(fake_H).detach()
-                pred_d_real = self.netD(var_ref)
-                l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True) * 0.5 / self.mega_batch_factor
-                with amp.scale_loss(l_d_real, self.optimizer_D, loss_id=2) as l_d_real_scaled:
-                    l_d_real_scaled.backward()
-                pred_d_fake = self.netD(fake_H)
-                l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real.detach()), False) * 0.5 / self.mega_batch_factor
-                with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled:
-                    l_d_fake_scaled.backward()
-        self.optimizer_D.step()
+                # Apply noise to the inputs to slow discriminator convergence.
+                var_ref = (var_ref[0] + noise,) + var_ref[1:]
+                fake_H = (fake_H[0] + noise,) + fake_H[1:]
+                if self.opt['train']['gan_type'] == 'gan':
+                    # need to forward and backward separately, since batch norm statistics differ
+                    # real
+                    pred_d_real = self.netD(var_ref)
+                    l_d_real = self.cri_gan(pred_d_real, True) / self.mega_batch_factor
+                    with amp.scale_loss(l_d_real, self.optimizer_D, loss_id=2) as l_d_real_scaled:
+                        l_d_real_scaled.backward()
+                    # fake
+                    pred_d_fake = self.netD(fake_H)
+                    l_d_fake = self.cri_gan(pred_d_fake, False) / self.mega_batch_factor
+                    with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled:
+                        l_d_fake_scaled.backward()
+                elif self.opt['train']['gan_type'] == 'ragan':
+                    # pred_d_real = self.netD(var_ref)
+                    # pred_d_fake = self.netD(fake_H.detach())  # detach to avoid BP to G
+                    # l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True)
+                    # l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real), False)
+                    # l_d_total = (l_d_real + l_d_fake) / 2
+                    # l_d_total.backward()
+                    pred_d_fake = self.netD(fake_H).detach()
+                    pred_d_real = self.netD(var_ref)
+                    l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True) * 0.5 / self.mega_batch_factor
+                    with amp.scale_loss(l_d_real, self.optimizer_D, loss_id=2) as l_d_real_scaled:
+                        l_d_real_scaled.backward()
+                    pred_d_fake = self.netD(fake_H)
+                    l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real.detach()), False) * 0.5 / self.mega_batch_factor
+                    with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled:
+                        l_d_fake_scaled.backward()
+            self.optimizer_D.step()
 
         # Log sample images from first microbatch.
         if step % 50 == 0:
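Note on the 'ragan' branches above: the relativistic average GAN loss scores each sample against the mean score of the opposite class instead of against an absolute real/fake label. Below is a minimal single-pass sketch of that pairing, not the repo's implementation: cri_gan is approximated by nn.BCEWithLogitsLoss with explicit target tensors (the repo's cri_gan takes a boolean target), and netD, real, and fake are hypothetical stand-ins.

import torch
import torch.nn as nn

cri_gan = nn.BCEWithLogitsLoss()  # stand-in for the repo's GAN criterion

def ragan_g_loss(netD, real, fake):
    # Generator step: the real-image logits are detached (D is frozen here),
    # so gradients only flow back through the fake batch into G.
    pred_real = netD(real).detach()
    pred_fake = netD(fake)
    return (cri_gan(pred_real - pred_fake.mean(), torch.zeros_like(pred_real)) +
            cri_gan(pred_fake - pred_real.mean(), torch.ones_like(pred_fake))) / 2

def ragan_d_loss(netD, real, fake):
    # Discriminator step: the fake batch is detached so no gradient reaches G.
    pred_fake = netD(fake.detach())
    pred_real = netD(real)
    return (cri_gan(pred_real - pred_fake.mean(), torch.ones_like(pred_real)) +
            cri_gan(pred_fake - pred_real.mean(), torch.zeros_like(pred_fake))) / 2

The patch itself backwards each discriminator term separately through amp.scale_loss, halving each term and dividing by mega_batch_factor so that AMP loss scaling and gradient accumulation work per microbatch; the sketch collapses that into one combined loss per step.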