Apply fixes to grad discriminator
parent 96d66f51c5
commit 6d25bcd5df
@@ -395,7 +395,6 @@ class SRGANModel(BaseModel):
                    using_gan_img = False
                    # Get image gradients for later use.
                    fake_H_grad = self.get_grad_nopadding(fake_GenOut)
                    var_ref_grad = self.get_grad_nopadding(var_ref)
                    var_H_grad_nopadding = self.get_grad_nopadding(var_H)
                    self.spsr_grad_GenOut.append(grad_LR)
                else:
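Note: `get_grad_nopadding` is not part of this hunk. In SPSR-style models it is usually a fixed, non-learned difference filter applied per channel to extract an image-gradient map. The sketch below is an assumption of that shape, not this repository's exact implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class GradNoPadding(nn.Module):
    """Fixed vertical/horizontal difference filters; returns a gradient-magnitude map."""
    def __init__(self):
        super().__init__()
        kernel_v = torch.tensor([[0., -1., 0.], [0., 0., 0.], [0., 1., 0.]])
        kernel_h = torch.tensor([[0., 0., 0.], [-1., 0., 1.], [0., 0., 0.]])
        self.register_buffer('weight_v', kernel_v.view(1, 1, 3, 3))
        self.register_buffer('weight_h', kernel_h.view(1, 1, 3, 3))

    def forward(self, x):
        grads = []
        for c in range(x.shape[1]):  # apply the filters to each channel independently
            xc = x[:, c:c + 1]
            gv = F.conv2d(xc, self.weight_v, padding=1)
            gh = F.conv2d(xc, self.weight_h, padding=1)
            grads.append(torch.sqrt(gv ** 2 + gh ** 2 + 1e-6))
        return torch.cat(grads, dim=1)
```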
@@ -477,10 +476,7 @@ class SRGANModel(BaseModel):
                if self.spsr_enabled and self.cri_grad_gan:  # grad G gan + cls loss
                    pred_g_fake_grad = self.netD_grad(fake_H_grad)
                    pred_d_real_grad = self.netD_grad(var_ref_grad).detach()

                    l_g_gan_grad = self.l_gan_grad_w * (self.cri_grad_gan(pred_d_real_grad - torch.mean(pred_g_fake_grad), False) +
                                                        self.cri_grad_gan(pred_g_fake_grad - torch.mean(pred_d_real_grad), True)) / 2
                    l_g_gan_grad = self.l_gan_grad_w * self.cri_grad_gan(pred_g_fake_grad, True)
                    l_g_total += l_g_gan_grad

                # Scale the loss down by the batch factor.
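This hunk shows two formulations of the grad-branch generator loss side by side: a relativistic-average (RAGAN) term built from both real and fake logits, and a plain term that only needs the fake logits. A minimal sketch of the two, assuming `cri_grad_gan` behaves like a BCE-with-logits GAN criterion (an assumption; the repository's GANLoss wrapper is not shown here):

```python
import torch
import torch.nn.functional as F

def gan_g_loss(pred, target_real=True):
    # Standard (non-relativistic) generator loss: push the given logits toward "real" or "fake".
    target = torch.ones_like(pred) if target_real else torch.zeros_like(pred)
    return F.binary_cross_entropy_with_logits(pred, target)

def ragan_g_loss(pred_d_real, pred_g_fake):
    # Relativistic-average generator loss: real should look less real than the average fake,
    # and fake should look more real than the average real.
    loss_real = gan_g_loss(pred_d_real - pred_g_fake.mean(), target_real=False)
    loss_fake = gan_g_loss(pred_g_fake - pred_d_real.mean(), target_real=True)
    return (loss_real + loss_fake) / 2
```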
@@ -508,7 +504,6 @@ class SRGANModel(BaseModel):
            noise = torch.randn_like(var_ref) * noise_theta
            noise.to(self.device)
            self.optimizer_D.zero_grad()
            real_disc_images = []
            fake_disc_images = []
            for var_L, var_LGAN, var_H, var_ref, pix in zip(self.var_L, self.gan_img, self.var_H, self.var_ref, self.pix):
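One detail worth noting in this context block: `noise.to(self.device)` returns a new tensor rather than moving `noise` in place, so its result is discarded as written; this still works because `torch.randn_like(var_ref)` already lives on `var_ref`'s device. A small sketch of the instance-noise idea with hypothetical names:

```python
import torch

def add_instance_noise(real, fake, noise_theta):
    # randn_like inherits device/dtype from its argument, so no .to() call is needed;
    # tensor.to(device) returns a new tensor and is a no-op when its result is discarded.
    noise = torch.randn_like(real) * noise_theta
    return real + noise, fake + noise
```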
@@ -533,6 +528,7 @@ class SRGANModel(BaseModel):
                fake_H = fake_H + noise
                l_d_fea_real = 0
                l_d_fea_fake = 0
                self.optimizer_D.zero_grad()
                if self.opt['train']['gan_type'] == 'pixgan_fea':
                    # Compute a feature loss which is added to the GAN loss computed later to guide the discriminator better.
                    disc_fea_scale = .1
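The 'pixgan_fea' path adds a small feature-space penalty on top of the discriminator's GAN loss. The feature network is not shown in this hunk; the sketch below is only a plausible shape of such a term, assuming a frozen extractor `netF` and an L1 criterion (both hypothetical here).

```python
import torch
import torch.nn as nn

l1_criterion = nn.L1Loss()

def disc_feature_loss(netF, disc_img, target_img, disc_fea_scale=0.1):
    # Compare features of the image shown to the discriminator against features of the
    # ground-truth image, scaled down so it only nudges the GAN loss computed later.
    with torch.no_grad():
        target_feats = netF(target_img)
    return disc_fea_scale * l1_criterion(netF(disc_img), target_feats)
```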
@@ -548,14 +544,14 @@ class SRGANModel(BaseModel):
                    pred_d_real = self.netD(var_ref)
                    l_d_real = self.cri_gan(pred_d_real, True) / self.mega_batch_factor
                    l_d_real_log = l_d_real * self.mega_batch_factor
                    with amp.scale_loss(l_d_real, self.optimizer_D, loss_id=2) as l_d_real_scaled:
                        l_d_real_scaled.backward()
                    # fake
                    pred_d_fake = self.netD(fake_H)
                    l_d_fake = self.cri_gan(pred_d_fake, False) / self.mega_batch_factor
                    l_d_fake_log = l_d_fake * self.mega_batch_factor
                    with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled:
                        l_d_fake_scaled.backward()

                    l_d_total = (l_d_real + l_d_fake) / 2
                    with amp.scale_loss(l_d_total, self.optimizer_D, loss_id=1) as l_d_total_scaled:
                        l_d_total_scaled.backward()
                if 'pixgan' in self.opt['train']['gan_type']:
                    pixdisc_channels, pixdisc_output_reduction = self.netD.module.pixgan_parameters()
                    disc_output_shape = (var_ref.shape[0], pixdisc_channels, var_ref.shape[2] // pixdisc_output_reduction, var_ref.shape[3] // pixdisc_output_reduction)
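The `amp.scale_loss(...)` blocks come from NVIDIA apex's AMP API: the context manager yields a loss scaled for mixed precision, and calling `.backward()` inside it accumulates gradients into the optimizer's master parameters. Dividing each term by `mega_batch_factor` before the backward pass implements gradient accumulation over that many sub-batches. A minimal sketch of the pattern (assumes apex is installed and `amp.initialize` has already been called on the model and optimizer):

```python
from apex import amp

def accumulate_disc_loss(loss, optimizer, mega_batch_factor, loss_id=1):
    # Scale for the "mega batch" so gradients summed over all sub-batches average out,
    # then backprop through apex's per-loss scaler identified by loss_id.
    loss = loss / mega_batch_factor
    with amp.scale_loss(loss, optimizer, loss_id=loss_id) as scaled_loss:
        scaled_loss.backward()
    return loss
```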
@@ -599,15 +595,15 @@ class SRGANModel(BaseModel):
                    l_d_real = self.cri_gan(pred_d_real, real) / self.mega_batch_factor
                    l_d_real_log = l_d_real * self.mega_batch_factor
                    l_d_real += l_d_fea_real
                    with amp.scale_loss(l_d_real, self.optimizer_D, loss_id=2) as l_d_real_scaled:
                        l_d_real_scaled.backward()
                    # fake
                    pred_d_fake = self.netD(fake_H)
                    l_d_fake = self.cri_gan(pred_d_fake, fake) / self.mega_batch_factor
                    l_d_fake_log = l_d_fake * self.mega_batch_factor
                    l_d_fake += l_d_fea_fake
                    with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled:
                        l_d_fake_scaled.backward()

                    l_d_total = (l_d_real + l_d_fake) / 2
                    with amp.scale_loss(l_d_total, self.optimizer_D, loss_id=1) as l_d_total_scaled:
                        l_d_total_scaled.backward()

                    pdr = pred_d_real.detach() + torch.abs(torch.min(pred_d_real))
                    pdr = pdr / torch.max(pdr)
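In the 'pixgan' path the discriminator emits a map of predictions rather than a single logit, so the `real`/`fake` targets are per-pixel maps with shape `disc_output_shape`, and `pdr` is simply the real prediction shifted and normalized into [0, 1] so it can be inspected or visualized. A sketch of constructing those targets, assuming the shapes shown in the hunk above:

```python
import torch

def make_pixgan_targets(var_ref, pixdisc_channels, pixdisc_output_reduction, device):
    # One target value per discriminator output pixel.
    disc_output_shape = (var_ref.shape[0], pixdisc_channels,
                         var_ref.shape[2] // pixdisc_output_reduction,
                         var_ref.shape[3] // pixdisc_output_reduction)
    real = torch.ones(disc_output_shape, device=device)
    fake = torch.zeros(disc_output_shape, device=device)
    return real, fake

def normalize_for_inspection(pred):
    # Shift predictions to be non-negative, then scale to [0, 1] (mirrors the pdr lines above).
    p = pred.detach() + torch.abs(torch.min(pred))
    return p / torch.max(p)
```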
@@ -643,7 +639,6 @@ class SRGANModel(BaseModel):
                    print("Disc forward/backward 2 (RAGAN) %f" % (time() - _t,))
                    _t = time()

                # Append var_ref here, so that we can inspect the alterations the disc made if pixgan
                var_ref_skips.append(var_ref.detach())
                self.fake_H.append(fake_H.detach())
            self.optimizer_D.step()
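`var_ref` and `fake_H` are detached before being cached so the stored copies carry no autograd graph; they are kept only for inspection and for the gradient-discriminator pass below. A tiny illustration of why the detach matters:

```python
import torch

x = torch.randn(4, requires_grad=True)
out = (x * 2).sum()    # builds an autograd graph
cached = out.detach()  # graph-free copy: safe to store without retaining the graph or backpropagating later
assert not cached.requires_grad
```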
@@ -657,20 +652,19 @@ class SRGANModel(BaseModel):
            for p in self.netD_grad.parameters():
                p.requires_grad = True
            self.optimizer_D_grad.zero_grad()

            for var_ref, fake_H in zip(self.var_ref, self.fake_H):
            for var_ref, fake_H in zip(self.var_ref_skips, self.fake_H):
                fake_H_grad = self.get_grad_nopadding(fake_H)
                var_ref_grad = self.get_grad_nopadding(var_ref)
                pred_d_real_grad = self.netD_grad(var_ref_grad)
                pred_d_fake_grad = self.netD_grad(fake_H_grad.detach())  # detach to avoid BP to G
                if self.opt['train']['gan_type'] == 'gan':
                    l_d_real_grad = self.cri_grad_gan(pred_d_real_grad - torch.mean(pred_d_fake_grad), True)
                    l_d_fake_grad = self.cri_grad_gan(pred_d_fake_grad - torch.mean(pred_d_real_grad), False)
                    l_d_real_grad = self.cri_gan(pred_d_real_grad, True) / self.mega_batch_factor
                    l_d_fake_grad = self.cri_gan(pred_d_fake_grad, False) / self.mega_batch_factor
                elif self.opt['train']['gan_type'] == 'pixgan':
                    real = torch.ones_like(pred_d_real_grad)
                    fake = torch.zeros_like(pred_d_fake_grad)
                    l_d_real_grad = self.cri_grad_gan(pred_d_real_grad - torch.mean(pred_d_fake_grad), real)
                    l_d_fake_grad = self.cri_grad_gan(pred_d_fake_grad - torch.mean(pred_d_real_grad), fake)
                    l_d_real_grad = self.cri_grad_gan(pred_d_real_grad, real)
                    l_d_fake_grad = self.cri_grad_gan(pred_d_fake_grad, fake)
                l_d_total_grad = (l_d_real_grad + l_d_fake_grad) / 2
                l_d_total_grad /= self.mega_batch_factor
                with amp.scale_loss(l_d_total_grad, self.optimizer_D_grad, loss_id=2) as l_d_total_grad_scaled:
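Putting the pieces of this hunk together, one gradient-discriminator update looks roughly like the sketch below. The names mirror the diff; `cri_grad_gan` is assumed to be a vanilla BCE-with-logits GAN criterion and apex AMP is assumed to be initialized (both assumptions, not confirmed by the hunk).

```python
from apex import amp

def d_grad_step(netD_grad, optimizer_D_grad, cri_grad_gan, get_grad_nopadding,
                var_ref_skips, fake_Hs, mega_batch_factor):
    for p in netD_grad.parameters():
        p.requires_grad = True
    optimizer_D_grad.zero_grad()
    for var_ref, fake_H in zip(var_ref_skips, fake_Hs):
        var_ref_grad = get_grad_nopadding(var_ref)
        fake_H_grad = get_grad_nopadding(fake_H)
        pred_real = netD_grad(var_ref_grad)
        pred_fake = netD_grad(fake_H_grad.detach())  # detach: no backprop into the generator
        l_real = cri_grad_gan(pred_real, True)
        l_fake = cri_grad_gan(pred_fake, False)
        # Average real/fake terms, then scale for mega-batch gradient accumulation.
        l_total = (l_real + l_fake) / 2 / mega_batch_factor
        with amp.scale_loss(l_total, optimizer_D_grad, loss_id=2) as scaled:
            scaled.backward()
    optimizer_D_grad.step()
```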
@@ -932,4 +926,5 @@ class SRGANModel(BaseModel):
    def save(self, iter_step):
        self.save_network(self.netG, 'G', iter_step)
        self.save_network(self.netD, 'D', iter_step)
        self.save_network(self.netD_grad, 'D_grad', iter_step)
        if self.spsr_enabled:
            self.save_network(self.netD_grad, 'D_grad', iter_step)
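The save hunk gates the gradient-discriminator checkpoint behind `self.spsr_enabled`, since `netD_grad` only exists for SPSR-style runs. A minimal sketch of that pattern using plain `torch.save` (the repository's `save_network` helper, which also handles naming and DataParallel unwrapping, is not shown here):

```python
import torch

def save(self, iter_step):
    torch.save(self.netG.state_dict(), f'{iter_step}_G.pth')
    torch.save(self.netD.state_dict(), f'{iter_step}_D.pth')
    if self.spsr_enabled:  # netD_grad is only constructed when the SPSR branch is active
        torch.save(self.netD_grad.state_dict(), f'{iter_step}_D_grad.pth')
```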