forked from mrq/DL-Art-School
Enable RAGAN in SPSR, retrofit old RAGAN for efficiency
parent 3ab39f0d22
commit b8a4df0a0a
@@ -482,9 +482,15 @@ class SRGANModel(BaseModel):
                l_g_gan_log = l_g_gan / self.l_gan_w
                l_g_total += l_g_gan

            if self.spsr_enabled and self.cri_grad_gan: # grad G gan + cls loss
            if self.spsr_enabled and self.cri_grad_gan:
                pred_g_fake_grad = self.netD_grad(fake_H_grad)
                l_g_gan_grad = self.l_gan_grad_w * self.cri_grad_gan(pred_g_fake_grad, True)
                if self.opt['train']['gan_type'] == 'gan' or 'pixgan' in self.opt['train']['gan_type']:
                    l_g_gan_grad = self.l_gan_grad_w * self.cri_grad_gan(pred_g_fake_grad, True)
                elif self.opt['train']['gan_type'] == 'ragan':
                    pred_g_real_grad = self.netD(self.get_grad_nopadding(var_ref)).detach()
                    l_g_gan = self.l_gan_w * (
                        self.cri_gan(pred_g_real_grad - torch.mean(pred_g_fake_grad), False) +
                        self.cri_gan(pred_g_fake_grad - torch.mean(pred_g_real_grad), True)) / 2
                l_g_total += l_g_gan_grad

            # Scale the loss down by the batch factor.
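For context on the ragan branch added above: in the relativistic average GAN (RAGAN) formulation, the generator is not asked to make fakes look real in absolute terms; it is trained so that fake outputs score higher than the average real prediction while real samples score lower than the average fake prediction. A minimal, self-contained sketch of that generator term follows; the helper name and the use of BCE-with-logits are illustrative assumptions, not this repository's cri_gan implementation.

import torch
import torch.nn.functional as F

def ragan_generator_loss(d_real_logits, d_fake_logits):
    # Relativistic average generator term: push fakes above the mean real
    # score and reals below the mean fake score, then average the two halves.
    real_rel = d_real_logits - d_fake_logits.mean()
    fake_rel = d_fake_logits - d_real_logits.mean()
    loss_real = F.binary_cross_entropy_with_logits(real_rel, torch.zeros_like(real_rel))
    loss_fake = F.binary_cross_entropy_with_logits(fake_rel, torch.ones_like(fake_rel))
    return (loss_real + loss_fake) / 2

In the hunk above, cri_gan(x, False) and cri_gan(x, True) play the role of the two BCE terms, the real-branch prediction is detached so only the generator receives gradients, and the averaged result is scaled by the configured GAN weight.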
@@ -622,30 +628,12 @@ class SRGANModel(BaseModel):
                elif self.opt['train']['gan_type'] == 'ragan':
                    pred_d_fake = self.netD(fake_H).detach()
                    pred_d_real = self.netD(var_ref)

                    if _profile:
                        print("Double disc forward (RAGAN) %f" % (time() - _t,))
                        _t = time()

                    l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True) * 0.5 / self.mega_batch_factor
                    l_d_real_log = l_d_real * self.mega_batch_factor * 2
                    with amp.scale_loss(l_d_real, self.optimizer_D, loss_id=2) as l_d_real_scaled:
                        l_d_real_scaled.backward()

                    if _profile:
                        print("Disc backward 1 (RAGAN) %f" % (time() - _t,))
                        _t = time()

                    pred_d_fake = self.netD(fake_H)
                    l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real.detach()), False) * 0.5 / self.mega_batch_factor
                    l_d_fake_log = l_d_fake * self.mega_batch_factor * 2
                    with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled:
                        l_d_fake_scaled.backward()

                    if _profile:
                        print("Disc forward/backward 2 (RAGAN) %f" % (time() - _t,))
                        _t = time()

                    l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True)
                    l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real), False)
                    l_d_total = (l_d_real + l_d_fake) / 2
                    l_d_total /= self.mega_batch_factor
                    with amp.scale_loss(l_d_total, self.optimizer_D, loss_id=1) as l_d_total_scaled:
                        l_d_total_scaled.backward()
                var_ref_skips.append(var_ref.detach())
                self.fake_H.append(fake_H.detach())
                self.optimizer_D.step()
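The "retrofit old RAGAN for efficiency" part of the commit shows up in this hunk: the earlier RAGAN discriminator path scaled and backpropagated the real and fake halves of the loss separately (two amp.scale_loss(...).backward() calls, with an extra discriminator forward on the fake batch in between), while the replacement averages both relativistic terms into l_d_total, divides by mega_batch_factor to account for gradient accumulation, and backpropagates once. The cri_gan criterion used throughout is not shown in the diff; a common shape for such a criterion is sketched below as an assumption based on typical SRGAN/ESRGAN trainers, not this repository's exact class.

import torch
import torch.nn as nn

class GANLoss(nn.Module):
    # Wraps BCE-with-logits so callers can pass a boolean "target is real"
    # flag, matching the cri_gan(pred, True) / cri_gan(pred, False) calls above.
    def __init__(self, real_label=1.0, fake_label=0.0):
        super().__init__()
        self.real_label = real_label
        self.fake_label = fake_label
        self.loss = nn.BCEWithLogitsLoss()

    def forward(self, pred, target_is_real):
        target_val = self.real_label if target_is_real else self.fake_label
        target = torch.full_like(pred, target_val)
        return self.loss(pred, target)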
@@ -672,6 +660,12 @@ class SRGANModel(BaseModel):
                    fake = torch.zeros_like(pred_d_fake_grad)
                    l_d_real_grad = self.cri_grad_gan(pred_d_real_grad, real)
                    l_d_fake_grad = self.cri_grad_gan(pred_d_fake_grad, fake)
                elif self.opt['train']['gan_type'] == 'ragan':
                    pred_g_fake_grad = self.netD_grad(self.fake_H_grad)
                    pred_d_real_grad = self.netD_grad(self.var_ref_grad).detach()
                    l_d_real_grad = self.cri_grad_gan(pred_d_real_grad - torch.mean(pred_g_fake_grad), True)
                    l_d_fake_grad = self.cri_grad_gan(pred_g_fake_grad - torch.mean(pred_d_real_grad), False)

                l_d_total_grad = (l_d_real_grad + l_d_fake_grad) / 2
                l_d_total_grad /= self.mega_batch_factor
                with amp.scale_loss(l_d_total_grad, self.optimizer_D_grad, loss_id=2) as l_d_total_grad_scaled:
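This hunk gives the gradient-branch discriminator (netD_grad) the same relativistic treatment on its own update: its predictions on the fake and real gradient maps are re-centred on each other's batch mean before being scored, and the two halves are averaged and divided by mega_batch_factor ahead of the amp-scaled backward. A standalone sketch of that discriminator-side term, mirroring the generator sketch earlier and again using illustrative names rather than the repository's cri_grad_gan:

import torch
import torch.nn.functional as F

def ragan_discriminator_loss(d_real_logits, d_fake_logits):
    # Discriminator-side relativistic average term: reals should score above
    # the mean fake prediction, fakes below the mean real prediction.
    real_rel = d_real_logits - d_fake_logits.mean()
    fake_rel = d_fake_logits - d_real_logits.mean()
    loss_real = F.binary_cross_entropy_with_logits(real_rel, torch.ones_like(real_rel))
    loss_fake = F.binary_cross_entropy_with_logits(fake_rel, torch.zeros_like(fake_rel))
    return (loss_real + loss_fake) / 2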
@@ -382,7 +382,6 @@ class SPSRNetSimplifiedNoSkip(nn.Module):
        self._branch_pretrain_HR_conv1 = ConvGnLelu(nf, out_nc, kernel_size=3, norm=False, activation=False, bias=False)

    def forward(self, x):

        x_grad = self.get_g_nopadding(x)
        x = self.model_fea_conv(x)