From b8a4df0a0a1b2f0f0b182e99563adb45d2b26c4a Mon Sep 17 00:00:00 2001
From: James Betker
Date: Wed, 5 Aug 2020 10:33:09 -0600
Subject: [PATCH] Enable RAGAN in SPSR, retrofit old RAGAN for efficiency

---
 codes/models/SRGAN_model.py     | 46 ++++++++++++++-------------
 codes/models/archs/SPSR_arch.py |  1 -
 2 files changed, 20 insertions(+), 27 deletions(-)

diff --git a/codes/models/SRGAN_model.py b/codes/models/SRGAN_model.py
index f2d3d2dd..39a01e5b 100644
--- a/codes/models/SRGAN_model.py
+++ b/codes/models/SRGAN_model.py
@@ -482,9 +482,15 @@ class SRGANModel(BaseModel):
                     l_g_gan_log = l_g_gan / self.l_gan_w
                     l_g_total += l_g_gan
 
-                if self.spsr_enabled and self.cri_grad_gan:  # grad G gan + cls loss
+                if self.spsr_enabled and self.cri_grad_gan:
                     pred_g_fake_grad = self.netD_grad(fake_H_grad)
-                    l_g_gan_grad = self.l_gan_grad_w * self.cri_grad_gan(pred_g_fake_grad, True)
+                    if self.opt['train']['gan_type'] == 'gan' or 'pixgan' in self.opt['train']['gan_type']:
+                        l_g_gan_grad = self.l_gan_grad_w * self.cri_grad_gan(pred_g_fake_grad, True)
+                    elif self.opt['train']['gan_type'] == 'ragan':
+                        pred_g_real_grad = self.netD_grad(self.get_grad_nopadding(var_ref)).detach()
+                        l_g_gan_grad = self.l_gan_grad_w * (
+                            self.cri_grad_gan(pred_g_real_grad - torch.mean(pred_g_fake_grad), False) +
+                            self.cri_grad_gan(pred_g_fake_grad - torch.mean(pred_g_real_grad), True)) / 2
                     l_g_total += l_g_gan_grad
 
                 # Scale the loss down by the batch factor.
@@ -622,30 +628,12 @@ class SRGANModel(BaseModel):
                 elif self.opt['train']['gan_type'] == 'ragan':
                     pred_d_fake = self.netD(fake_H).detach()
                     pred_d_real = self.netD(var_ref)
-
-                    if _profile:
-                        print("Double disc forward (RAGAN) %f" % (time() - _t,))
-                        _t = time()
-
-                    l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True) * 0.5 / self.mega_batch_factor
-                    l_d_real_log = l_d_real * self.mega_batch_factor * 2
-                    with amp.scale_loss(l_d_real, self.optimizer_D, loss_id=2) as l_d_real_scaled:
-                        l_d_real_scaled.backward()
-
-                    if _profile:
-                        print("Disc backward 1 (RAGAN) %f" % (time() - _t,))
-                        _t = time()
-
-                    pred_d_fake = self.netD(fake_H)
-                    l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real.detach()), False) * 0.5 / self.mega_batch_factor
-                    l_d_fake_log = l_d_fake * self.mega_batch_factor * 2
-                    with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled:
-                        l_d_fake_scaled.backward()
-
-                    if _profile:
-                        print("Disc forward/backward 2 (RAGAN) %f" % (time() - _t,))
-                        _t = time()
-
+                    l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True)
+                    l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real), False)
+                    l_d_total = (l_d_real + l_d_fake) / 2
+                    l_d_total /= self.mega_batch_factor
+                    with amp.scale_loss(l_d_total, self.optimizer_D, loss_id=1) as l_d_total_scaled:
+                        l_d_total_scaled.backward()
                 var_ref_skips.append(var_ref.detach())
                 self.fake_H.append(fake_H.detach())
             self.optimizer_D.step()
@@ -672,6 +660,12 @@ class SRGANModel(BaseModel):
                         fake = torch.zeros_like(pred_d_fake_grad)
                         l_d_real_grad = self.cri_grad_gan(pred_d_real_grad, real)
                         l_d_fake_grad = self.cri_grad_gan(pred_d_fake_grad, fake)
+                    elif self.opt['train']['gan_type'] == 'ragan':
+                        pred_g_fake_grad = self.netD_grad(self.fake_H_grad)
+                        pred_d_real_grad = self.netD_grad(self.var_ref_grad).detach()
+                        l_d_real_grad = self.cri_grad_gan(pred_d_real_grad - torch.mean(pred_g_fake_grad), True)
+                        l_d_fake_grad = self.cri_grad_gan(pred_g_fake_grad - torch.mean(pred_d_real_grad), False)
+
                     l_d_total_grad = (l_d_real_grad + l_d_fake_grad) / 2
                     l_d_total_grad /= self.mega_batch_factor
                     with amp.scale_loss(l_d_total_grad, self.optimizer_D_grad, loss_id=2) as l_d_total_grad_scaled:
diff --git a/codes/models/archs/SPSR_arch.py b/codes/models/archs/SPSR_arch.py
index 56eb944b..b4f5da7d 100644
--- a/codes/models/archs/SPSR_arch.py
+++ b/codes/models/archs/SPSR_arch.py
@@ -382,7 +382,6 @@ class SPSRNetSimplifiedNoSkip(nn.Module):
         self._branch_pretrain_HR_conv1 = ConvGnLelu(nf, out_nc, kernel_size=3, norm=False, activation=False, bias=False)
 
     def forward(self, x):
-        x_grad = self.get_g_nopadding(x)
         x = self.model_fea_conv(x)
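For reference, below is a minimal sketch of the relativistic average GAN (RAGAN) loss that both retrofitted code paths above compute in a single combined backward pass. It is an illustration, not code from this repository: the toy discriminator, the BCEWithLogitsLoss criterion, and the plain .backward() calls are assumptions standing in for self.netD / self.netD_grad, self.cri_gan / self.cri_grad_gan, and the apex amp.scale_loss wrapper used in the patch.

import torch
import torch.nn as nn

# Toy discriminator standing in for netD / netD_grad; the architecture is arbitrary.
disc = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 1),
)
criterion = nn.BCEWithLogitsLoss()  # stands in for cri_gan / cri_grad_gan

def ragan_d_loss(pred_real, pred_fake):
    # Discriminator: real samples should score above the average fake and vice versa.
    l_real = criterion(pred_real - pred_fake.mean(), torch.ones_like(pred_real))
    l_fake = criterion(pred_fake - pred_real.mean(), torch.zeros_like(pred_fake))
    return (l_real + l_fake) / 2

def ragan_g_loss(pred_real, pred_fake):
    # Generator: mirror of the discriminator loss, with the real/fake targets swapped.
    l_real = criterion(pred_real - pred_fake.mean(), torch.zeros_like(pred_real))
    l_fake = criterion(pred_fake - pred_real.mean(), torch.ones_like(pred_fake))
    return (l_real + l_fake) / 2

# var_ref plays the HR reference image, fake_H the generator output.
var_ref = torch.randn(4, 3, 32, 32)
fake_H = torch.randn(4, 3, 32, 32, requires_grad=True)

# Discriminator update: detach the generator output; one backward pass covers both terms.
l_d_total = ragan_d_loss(disc(var_ref), disc(fake_H.detach()))
l_d_total.backward()

# Generator update: detach the real-image predictions so gradients flow only
# through disc(fake_H) back to the generator output.
l_g_gan = ragan_g_loss(disc(var_ref).detach(), disc(fake_H))
l_g_gan.backward()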