Apply fixes to grad discriminator

This commit is contained in:
James Betker 2020-08-04 10:25:13 -06:00
parent 96d66f51c5
commit 6d25bcd5df

View File

@ -395,7 +395,6 @@ class SRGANModel(BaseModel):
using_gan_img = False
# Get image gradients for later use.
fake_H_grad = self.get_grad_nopadding(fake_GenOut)
var_ref_grad = self.get_grad_nopadding(var_ref)
var_H_grad_nopadding = self.get_grad_nopadding(var_H)
self.spsr_grad_GenOut.append(grad_LR)
else:
@ -477,10 +476,7 @@ class SRGANModel(BaseModel):
if self.spsr_enabled and self.cri_grad_gan: # grad G gan + cls loss
pred_g_fake_grad = self.netD_grad(fake_H_grad)
pred_d_real_grad = self.netD_grad(var_ref_grad).detach()
l_g_gan_grad = self.l_gan_grad_w * (self.cri_grad_gan(pred_d_real_grad - torch.mean(pred_g_fake_grad), False) +
self.cri_grad_gan(pred_g_fake_grad - torch.mean(pred_d_real_grad), True)) /2
l_g_gan_grad = self.l_gan_grad_w * self.cri_grad_gan(pred_g_fake_grad, True)
l_g_total += l_g_gan_grad
# Scale the loss down by the batch factor.
@ -508,7 +504,6 @@ class SRGANModel(BaseModel):
noise = torch.randn_like(var_ref) * noise_theta
noise.to(self.device)
self.optimizer_D.zero_grad()
real_disc_images = []
fake_disc_images = []
for var_L, var_LGAN, var_H, var_ref, pix in zip(self.var_L, self.gan_img, self.var_H, self.var_ref, self.pix):
@ -533,6 +528,7 @@ class SRGANModel(BaseModel):
fake_H = fake_H + noise
l_d_fea_real = 0
l_d_fea_fake = 0
self.optimizer_D.zero_grad()
if self.opt['train']['gan_type'] == 'pixgan_fea':
# Compute a feature loss which is added to the GAN loss computed later to guide the discriminator better.
disc_fea_scale = .1
@ -548,14 +544,14 @@ class SRGANModel(BaseModel):
pred_d_real = self.netD(var_ref)
l_d_real = self.cri_gan(pred_d_real, True) / self.mega_batch_factor
l_d_real_log = l_d_real * self.mega_batch_factor
with amp.scale_loss(l_d_real, self.optimizer_D, loss_id=2) as l_d_real_scaled:
l_d_real_scaled.backward()
# fake
pred_d_fake = self.netD(fake_H)
l_d_fake = self.cri_gan(pred_d_fake, False) / self.mega_batch_factor
l_d_fake_log = l_d_fake * self.mega_batch_factor
with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled:
l_d_fake_scaled.backward()
l_d_total = (l_d_real + l_d_fake) / 2
with amp.scale_loss(l_d_total, self.optimizer_D, loss_id=1) as l_d_total_scaled:
l_d_total_scaled.backward()
if 'pixgan' in self.opt['train']['gan_type']:
pixdisc_channels, pixdisc_output_reduction = self.netD.module.pixgan_parameters()
disc_output_shape = (var_ref.shape[0], pixdisc_channels, var_ref.shape[2] // pixdisc_output_reduction, var_ref.shape[3] // pixdisc_output_reduction)
@ -599,15 +595,15 @@ class SRGANModel(BaseModel):
l_d_real = self.cri_gan(pred_d_real, real) / self.mega_batch_factor
l_d_real_log = l_d_real * self.mega_batch_factor
l_d_real += l_d_fea_real
with amp.scale_loss(l_d_real, self.optimizer_D, loss_id=2) as l_d_real_scaled:
l_d_real_scaled.backward()
# fake
pred_d_fake = self.netD(fake_H)
l_d_fake = self.cri_gan(pred_d_fake, fake) / self.mega_batch_factor
l_d_fake_log = l_d_fake * self.mega_batch_factor
l_d_fake += l_d_fea_fake
with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled:
l_d_fake_scaled.backward()
l_d_total = (l_d_real + l_d_fake) / 2
with amp.scale_loss(l_d_total, self.optimizer_D, loss_id=1) as l_d_total_scaled:
l_d_total_scaled.backward()
pdr = pred_d_real.detach() + torch.abs(torch.min(pred_d_real))
pdr = pdr / torch.max(pdr)
@ -643,7 +639,6 @@ class SRGANModel(BaseModel):
print("Disc forward/backward 2 (RAGAN) %f" % (time() - _t,))
_t = time()
# Append var_ref here, so that we can inspect the alterations the disc made if pixgan
var_ref_skips.append(var_ref.detach())
self.fake_H.append(fake_H.detach())
self.optimizer_D.step()
@ -657,20 +652,19 @@ class SRGANModel(BaseModel):
for p in self.netD_grad.parameters():
p.requires_grad = True
self.optimizer_D_grad.zero_grad()
for var_ref, fake_H in zip(self.var_ref, self.fake_H):
for var_ref, fake_H in zip(self.var_ref_skips, self.fake_H):
fake_H_grad = self.get_grad_nopadding(fake_H)
var_ref_grad = self.get_grad_nopadding(var_ref)
pred_d_real_grad = self.netD_grad(var_ref_grad)
pred_d_fake_grad = self.netD_grad(fake_H_grad.detach()) # detach to avoid BP to G
if self.opt['train']['gan_type'] == 'gan':
l_d_real_grad = self.cri_grad_gan(pred_d_real_grad - torch.mean(pred_d_fake_grad), True)
l_d_fake_grad = self.cri_grad_gan(pred_d_fake_grad - torch.mean(pred_d_real_grad), False)
l_d_real_grad = self.cri_gan(pred_d_real_grad, True) / self.mega_batch_factor
l_d_fake_grad = self.cri_gan(pred_d_fake_grad, False) / self.mega_batch_factor
elif self.opt['train']['gan_type'] == 'pixgan':
real = torch.ones_like(pred_d_real_grad)
fake = torch.zeros_like(pred_d_fake_grad)
l_d_real_grad = self.cri_grad_gan(pred_d_real_grad - torch.mean(pred_d_fake_grad), real)
l_d_fake_grad = self.cri_grad_gan(pred_d_fake_grad - torch.mean(pred_d_real_grad), fake)
l_d_real_grad = self.cri_grad_gan(pred_d_real_grad, real)
l_d_fake_grad = self.cri_grad_gan(pred_d_fake_grad, fake)
l_d_total_grad = (l_d_real_grad + l_d_fake_grad) / 2
l_d_total_grad /= self.mega_batch_factor
with amp.scale_loss(l_d_total_grad, self.optimizer_D_grad, loss_id=2) as l_d_total_grad_scaled:
@ -932,4 +926,5 @@ class SRGANModel(BaseModel):
def save(self, iter_step):
self.save_network(self.netG, 'G', iter_step)
self.save_network(self.netD, 'D', iter_step)
self.save_network(self.netD_grad, 'D_grad', iter_step)
if self.spsr_enabled:
self.save_network(self.netD_grad, 'D_grad', iter_step)