diff --git a/codes/models/SRGAN_model.py b/codes/models/SRGAN_model.py index 39a01e5b..75e6061e 100644 --- a/codes/models/SRGAN_model.py +++ b/codes/models/SRGAN_model.py @@ -30,10 +30,6 @@ class SRGANModel(BaseModel): train_opt = opt['train'] self.spsr_enabled = 'spsr' in opt['model'] - # Only pixgan and gan are currently supported in spsr_mode - if self.spsr_enabled: - assert train_opt['gan_type'] == 'pixgan' or train_opt['gan_type'] == 'gan' - # define networks and load pretrained models self.netG = networks.define_G(opt).to(self.device) if self.is_train: @@ -488,7 +484,7 @@ class SRGANModel(BaseModel): l_g_gan_grad = self.l_gan_grad_w * self.cri_grad_gan(pred_g_fake_grad, True) elif self.opt['train']['gan_type'] == 'ragan': pred_g_real_grad = self.netD(self.get_grad_nopadding(var_ref)).detach() - l_g_gan = self.l_gan_w * ( + l_g_gan_grad = self.l_gan_w * ( self.cri_gan(pred_g_real_grad - torch.mean(pred_g_fake_grad), False) + self.cri_gan(pred_g_fake_grad - torch.mean(pred_g_real_grad), True)) / 2 l_g_total += l_g_gan_grad @@ -629,7 +625,9 @@ class SRGANModel(BaseModel): pred_d_fake = self.netD(fake_H).detach() pred_d_real = self.netD(var_ref) l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True) + l_d_real_log = l_d_real l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real), False) + l_d_fake_log = l_d_fake l_d_total = (l_d_real + l_d_fake) / 2 l_d_total /= self.mega_batch_factor with amp.scale_loss(l_d_total, self.optimizer_D, loss_id=1) as l_d_total_scaled: @@ -661,8 +659,8 @@ class SRGANModel(BaseModel): l_d_real_grad = self.cri_grad_gan(pred_d_real_grad, real) l_d_fake_grad = self.cri_grad_gan(pred_d_fake_grad, fake) elif self.opt['train']['gan_type'] == 'ragan': - pred_g_fake_grad = self.netD_grad(self.fake_H_grad) - pred_d_real_grad = self.netD_grad(self.var_ref_grad).detach() + pred_g_fake_grad = self.netD_grad(fake_H_grad) + pred_d_real_grad = self.netD_grad(var_ref_grad).detach() l_d_real_grad = self.cri_grad_gan(pred_d_real_grad - torch.mean(pred_g_fake_grad), True) l_d_fake_grad = self.cri_grad_gan(pred_g_fake_grad - torch.mean(pred_d_real_grad), False) diff --git a/codes/models/archs/SPSR_arch.py b/codes/models/archs/SPSR_arch.py index b4f5da7d..577a17dc 100644 --- a/codes/models/archs/SPSR_arch.py +++ b/codes/models/archs/SPSR_arch.py @@ -388,19 +388,12 @@ class SPSRNetSimplifiedNoSkip(nn.Module): x_ori = x for i in range(5): x = self.model_shortcut_blk[i](x) - x_fea1 = x - for i in range(5): x = self.model_shortcut_blk[i + 5](x) - x_fea2 = x - for i in range(5): x = self.model_shortcut_blk[i + 10](x) - x_fea3 = x - for i in range(5): x = self.model_shortcut_blk[i + 15](x) - x_fea4 = x x = self.model_shortcut_blk[20:](x) x = self.feature_lr_conv(x) @@ -430,7 +423,6 @@ class SPSRNetSimplifiedNoSkip(nn.Module): x_out = self._branch_pretrain_concat(x__branch_pretrain_cat) x_out = self._branch_pretrain_HR_conv0(x_out) x_out = self._branch_pretrain_HR_conv1(x_out) - ######### return x_out_branch, x_out, x_grad