From 26a6a5d512338fcd8a38c69629c5b40efba08cbb Mon Sep 17 00:00:00 2001
From: James Betker
Date: Wed, 5 Aug 2020 12:08:15 -0600
Subject: [PATCH] Compute grad GAN loss against both the branch and final
 target, simplify pixel loss

Also fixes a memory leak issue where we weren't detaching our loss stats
when logging them. This stabilizes memory usage substantially.
---
 codes/models/SRGAN_model.py | 60 ++++++++++++++++++++++---------------
 1 file changed, 36 insertions(+), 24 deletions(-)

diff --git a/codes/models/SRGAN_model.py b/codes/models/SRGAN_model.py
index 75e6061e..6af272bd 100644
--- a/codes/models/SRGAN_model.py
+++ b/codes/models/SRGAN_model.py
@@ -393,14 +393,13 @@ class SRGANModel(BaseModel):
             var_ref_skips = []
             for var_L, var_LGAN, var_H, var_ref, pix in zip(self.var_L, self.gan_img, self.var_H, self.var_ref, self.pix):
                 if self.spsr_enabled:
+                    using_gan_img = False
                     # SPSR models have outputs from three different branches.
                     fake_H_branch, fake_GenOut, grad_LR = self.netG(var_L)
                     fea_GenOut = fake_GenOut
-                    using_gan_img = False
+                    self.spsr_grad_GenOut.append(fake_H_branch)
                     # Get image gradients for later use.
                     fake_H_grad = self.get_grad_nopadding(fake_GenOut)
-                    var_H_grad_nopadding = self.get_grad_nopadding(var_H)
-                    self.spsr_grad_GenOut.append(fake_H_branch)
                 else:
                     if random.random() > self.gan_lq_img_use_prob:
                         fea_GenOut, fake_GenOut = self.netG(var_L)
@@ -428,11 +427,16 @@ class SRGANModel(BaseModel):
                         l_g_pix_log = l_g_pix / self.l_pix_w
                         l_g_total += l_g_pix
                     if self.spsr_enabled and self.cri_pix_grad:  # gradient pixel loss
+                        var_H_grad_nopadding = self.get_grad_nopadding(var_H)
                         l_g_pix_grad = self.l_pix_grad_w * self.cri_pix_grad(fake_H_grad, var_H_grad_nopadding)
                         l_g_total += l_g_pix_grad
                     if self.spsr_enabled and self.cri_pix_branch:  # branch pixel loss
-                        l_g_pix_grad_branch = self.l_pix_branch_w * self.cri_pix_branch(fake_H_branch,
-                                                                                        var_H_grad_nopadding)
+                        # The point of this loss is that the core structure of the grad image does not get mutated. Therefore,
+                        # downsample and compare against the input. The GAN loss will take care of the details in HR-space.
+                        var_L_grad = self.get_grad_nopadding(var_L)
+                        downsampled_H_branch = F.interpolate(fake_H_branch, size=var_L_grad.shape[2:], mode="nearest")
+                        l_g_pix_grad_branch = self.l_pix_branch_w * self.cri_pix_branch(downsampled_H_branch,
+                                                                                        var_L_grad)
                         l_g_total += l_g_pix_grad_branch
                     if self.fdpl_enabled and not using_gan_img:
                         l_g_fdpl = self.cri_fdpl(fea_GenOut, pix)
@@ -480,14 +484,19 @@ class SRGANModel(BaseModel):
 
                     if self.spsr_enabled and self.cri_grad_gan:
                         pred_g_fake_grad = self.netD_grad(fake_H_grad)
+                        pred_g_fake_grad_branch = self.netD_grad(fake_H_branch)
                         if self.opt['train']['gan_type'] == 'gan' or 'pixgan' in self.opt['train']['gan_type']:
                             l_g_gan_grad = self.l_gan_grad_w * self.cri_grad_gan(pred_g_fake_grad, True)
+                            l_g_gan_grad_branch = self.l_gan_grad_w * self.cri_grad_gan(pred_g_fake_grad_branch, True)
                         elif self.opt['train']['gan_type'] == 'ragan':
                             pred_g_real_grad = self.netD(self.get_grad_nopadding(var_ref)).detach()
                             l_g_gan_grad = self.l_gan_w * (
                                 self.cri_gan(pred_g_real_grad - torch.mean(pred_g_fake_grad), False) +
                                 self.cri_gan(pred_g_fake_grad - torch.mean(pred_g_real_grad), True)) / 2
-                        l_g_total += l_g_gan_grad
+                            l_g_gan_grad_branch = self.l_gan_w * (
+                                self.cri_gan(pred_g_real_grad - torch.mean(pred_g_fake_grad_branch), False) +
+                                self.cri_gan(pred_g_fake_grad_branch - torch.mean(pred_g_real_grad), True)) / 2
+                        l_g_total += l_g_gan_grad + l_g_gan_grad_branch
 
                     # Scale the loss down by the batch factor.
                     l_g_total_log = l_g_total
@@ -703,36 +712,39 @@ class SRGANModel(BaseModel):
         # Log metrics
         if step % self.D_update_ratio == 0 and step >= self.D_init_iters:
             if self.cri_pix and l_g_pix_log is not None:
-                self.add_log_entry('l_g_pix', l_g_pix_log.item())
+                self.add_log_entry('l_g_pix', l_g_pix_log.detach().item())
             if self.fdpl_enabled and l_g_fdpl is not None:
-                self.add_log_entry('l_g_fdpl', l_g_fdpl.item())
+                self.add_log_entry('l_g_fdpl', l_g_fdpl.detach().item())
             if self.cri_fea and l_g_fea_log is not None:
                 self.add_log_entry('feature_weight', fea_w)
-                self.add_log_entry('l_g_fea', l_g_fea_log.item())
-                self.add_log_entry('l_g_fix_disc', l_g_fix_disc.item())
+                self.add_log_entry('l_g_fea', l_g_fea_log.detach().item())
+                self.add_log_entry('l_g_fix_disc', l_g_fix_disc.detach().item())
             if self.l_gan_w > 0:
-                self.add_log_entry('l_g_gan', l_g_gan_log.item())
-            self.add_log_entry('l_g_total', l_g_total_log.item())
+                self.add_log_entry('l_g_gan', l_g_gan_log.detach().item())
+            self.add_log_entry('l_g_total', l_g_total_log.detach().item())
             if self.opt['train']['gan_type'] == 'pixgan_fea':
-                self.add_log_entry('l_d_fea_fake', l_d_fea_fake.item() * self.mega_batch_factor)
-                self.add_log_entry('l_d_fea_real', l_d_fea_real.item() * self.mega_batch_factor)
-                self.add_log_entry('l_d_fake_total', l_d_fake.item() * self.mega_batch_factor)
-                self.add_log_entry('l_d_real_total', l_d_real.item() * self.mega_batch_factor)
+                self.add_log_entry('l_d_fea_fake', l_d_fea_fake.detach().item() * self.mega_batch_factor)
+                self.add_log_entry('l_d_fea_real', l_d_fea_real.detach().item() * self.mega_batch_factor)
+                self.add_log_entry('l_d_fake_total', l_d_fake.detach().item() * self.mega_batch_factor)
+                self.add_log_entry('l_d_real_total', l_d_real.detach().item() * self.mega_batch_factor)
             if self.spsr_enabled:
                 if self.cri_pix_grad:
-                    self.add_log_entry('l_g_pix_grad_branch', l_g_pix_grad.item())
+                    self.add_log_entry('l_g_pix_grad_branch', l_g_pix_grad.detach().item())
                 if self.cri_pix_branch:
-                    self.add_log_entry('l_g_pix_grad_branch', l_g_pix_grad_branch.item())
+                    self.add_log_entry('l_g_pix_grad_branch', l_g_pix_grad_branch.detach().item())
+                if self.cri_grad_gan:
+                    self.add_log_entry('l_g_gan_grad', l_g_gan_grad.detach().item())
+                    self.add_log_entry('l_g_gan_grad_branch', l_g_gan_grad_branch.detach().item())
         if self.l_gan_w > 0 and step >= self.G_warmup:
-            self.add_log_entry('l_d_real', l_d_real_log.item())
-            self.add_log_entry('l_d_fake', l_d_fake_log.item())
+            self.add_log_entry('l_d_real', l_d_real_log.detach().item())
+            self.add_log_entry('l_d_fake', l_d_fake_log.detach().item())
             self.add_log_entry('D_fake', torch.mean(pred_d_fake.detach()))
-            self.add_log_entry('D_diff', torch.mean(pred_d_fake) - torch.mean(pred_d_real))
+            self.add_log_entry('D_diff', torch.mean(pred_d_fake.detach()) - torch.mean(pred_d_real.detach()))
             if self.spsr_enabled:
-                self.add_log_entry('l_d_real_grad', l_d_real_grad.item())
-                self.add_log_entry('l_d_fake_grad', l_d_fake_grad.item())
+                self.add_log_entry('l_d_real_grad', l_d_real_grad.detach().item())
+                self.add_log_entry('l_d_fake_grad', l_d_fake_grad.detach().item())
                 self.add_log_entry('D_fake', torch.mean(pred_d_fake_grad.detach()))
-                self.add_log_entry('D_diff', torch.mean(pred_d_fake_grad) - torch.mean(pred_d_real_grad))
+                self.add_log_entry('D_diff', torch.mean(pred_d_fake_grad.detach()) - torch.mean(pred_d_real_grad.detach()))
 
         # Log learning rates.
         for i, pg in enumerate(self.optimizer_G.param_groups):
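
For reference, the generator-side grad GAN change in the earlier hunks can be condensed as: score both the gradient map of the final SR output and the raw gradient-branch output with the gradient discriminator, then add both GAN losses to l_g_total. The sketch below restates the 'ragan' case with gan_loss and disc_grad as hypothetical stand-ins for self.cri_gan and self.netD_grad; it is an illustration of the pattern, not the project's implementation.

# Condensed sketch of the post-patch relativistic ('ragan') grad GAN term.
# gan_loss / disc_grad are assumed stand-ins, not classes from this repo.
import torch
import torch.nn.functional as F

def gan_loss(pred, target_is_real):
    # BCE-with-logits criterion standing in for self.cri_gan.
    target = torch.ones_like(pred) if target_is_real else torch.zeros_like(pred)
    return F.binary_cross_entropy_with_logits(pred, target)

def grad_gan_generator_loss(disc_grad, fake_H_grad, fake_H_branch, real_ref_grad, weight):
    pred_fake = disc_grad(fake_H_grad)             # gradient map of the final SR output
    pred_fake_branch = disc_grad(fake_H_branch)    # output of the gradient branch
    pred_real = disc_grad(real_ref_grad).detach()  # gradient map of the reference image
    l_grad = weight * (gan_loss(pred_real - pred_fake.mean(), False) +
                       gan_loss(pred_fake - pred_real.mean(), True)) / 2
    l_branch = weight * (gan_loss(pred_real - pred_fake_branch.mean(), False) +
                         gan_loss(pred_fake_branch - pred_real.mean(), True)) / 2
    return l_grad + l_branch  # both terms are added to l_g_total, as in the hunk above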
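
On the logging hunk: add_log_entry's implementation is not part of this patch, so the snippet below is only a minimal, hypothetical sketch of the pattern the commit moves toward. Holding graph-attached loss tensors for logging keeps their autograd graphs (and the activations they reference) alive; detaching, or converting to a plain Python float, before storing avoids that. The log_buffer name and helper are illustrative stand-ins, not the project's API.

# Minimal sketch (assumed stand-in, not SRGANModel.add_log_entry): never store a
# graph-attached tensor for logging; detach or convert to a Python float first.
import torch

log_buffer = {}  # hypothetical stand-in for the model's log store

def add_log_entry(key, value):
    if isinstance(value, torch.Tensor):
        # .detach() drops the autograd graph; .item() yields a plain float for scalars.
        value = value.detach().item() if value.numel() == 1 else value.detach().cpu()
    log_buffer[key] = value

# Usage, mirroring the logging hunk above:
l_g_pix = (torch.randn(8, requires_grad=True) ** 2).mean()
add_log_entry('l_g_pix', l_g_pix)  # stored as a float; the graph can be freed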