More RAGAN fixes
This commit is contained in:
parent
b8a4df0a0a
commit
299ee13988
|
@ -30,10 +30,6 @@ class SRGANModel(BaseModel):
|
|||
train_opt = opt['train']
|
||||
self.spsr_enabled = 'spsr' in opt['model']
|
||||
|
||||
# Only pixgan and gan are currently supported in spsr_mode
|
||||
if self.spsr_enabled:
|
||||
assert train_opt['gan_type'] == 'pixgan' or train_opt['gan_type'] == 'gan'
|
||||
|
||||
# define networks and load pretrained models
|
||||
self.netG = networks.define_G(opt).to(self.device)
|
||||
if self.is_train:
|
||||
|
@ -488,7 +484,7 @@ class SRGANModel(BaseModel):
|
|||
l_g_gan_grad = self.l_gan_grad_w * self.cri_grad_gan(pred_g_fake_grad, True)
|
||||
elif self.opt['train']['gan_type'] == 'ragan':
|
||||
pred_g_real_grad = self.netD(self.get_grad_nopadding(var_ref)).detach()
|
||||
l_g_gan = self.l_gan_w * (
|
||||
l_g_gan_grad = self.l_gan_w * (
|
||||
self.cri_gan(pred_g_real_grad - torch.mean(pred_g_fake_grad), False) +
|
||||
self.cri_gan(pred_g_fake_grad - torch.mean(pred_g_real_grad), True)) / 2
|
||||
l_g_total += l_g_gan_grad
|
||||
|
@ -629,7 +625,9 @@ class SRGANModel(BaseModel):
|
|||
pred_d_fake = self.netD(fake_H).detach()
|
||||
pred_d_real = self.netD(var_ref)
|
||||
l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True)
|
||||
l_d_real_log = l_d_real
|
||||
l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real), False)
|
||||
l_d_fake_log = l_d_fake
|
||||
l_d_total = (l_d_real + l_d_fake) / 2
|
||||
l_d_total /= self.mega_batch_factor
|
||||
with amp.scale_loss(l_d_total, self.optimizer_D, loss_id=1) as l_d_total_scaled:
|
||||
|
@ -661,8 +659,8 @@ class SRGANModel(BaseModel):
|
|||
l_d_real_grad = self.cri_grad_gan(pred_d_real_grad, real)
|
||||
l_d_fake_grad = self.cri_grad_gan(pred_d_fake_grad, fake)
|
||||
elif self.opt['train']['gan_type'] == 'ragan':
|
||||
pred_g_fake_grad = self.netD_grad(self.fake_H_grad)
|
||||
pred_d_real_grad = self.netD_grad(self.var_ref_grad).detach()
|
||||
pred_g_fake_grad = self.netD_grad(fake_H_grad)
|
||||
pred_d_real_grad = self.netD_grad(var_ref_grad).detach()
|
||||
l_d_real_grad = self.cri_grad_gan(pred_d_real_grad - torch.mean(pred_g_fake_grad), True)
|
||||
l_d_fake_grad = self.cri_grad_gan(pred_g_fake_grad - torch.mean(pred_d_real_grad), False)
|
||||
|
||||
|
|
|
@ -388,19 +388,12 @@ class SPSRNetSimplifiedNoSkip(nn.Module):
|
|||
x_ori = x
|
||||
for i in range(5):
|
||||
x = self.model_shortcut_blk[i](x)
|
||||
x_fea1 = x
|
||||
|
||||
for i in range(5):
|
||||
x = self.model_shortcut_blk[i + 5](x)
|
||||
x_fea2 = x
|
||||
|
||||
for i in range(5):
|
||||
x = self.model_shortcut_blk[i + 10](x)
|
||||
x_fea3 = x
|
||||
|
||||
for i in range(5):
|
||||
x = self.model_shortcut_blk[i + 15](x)
|
||||
x_fea4 = x
|
||||
|
||||
x = self.model_shortcut_blk[20:](x)
|
||||
x = self.feature_lr_conv(x)
|
||||
|
@ -430,7 +423,6 @@ class SPSRNetSimplifiedNoSkip(nn.Module):
|
|||
x_out = self._branch_pretrain_concat(x__branch_pretrain_cat)
|
||||
x_out = self._branch_pretrain_HR_conv0(x_out)
|
||||
x_out = self._branch_pretrain_HR_conv1(x_out)
|
||||
|
||||
#########
|
||||
return x_out_branch, x_out, x_grad
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user