Apply fixes to grad discriminator
parent 96d66f51c5
commit 6d25bcd5df
@@ -395,7 +395,6 @@ class SRGANModel(BaseModel):
                     using_gan_img = False
                     # Get image gradients for later use.
                     fake_H_grad = self.get_grad_nopadding(fake_GenOut)
-                    var_ref_grad = self.get_grad_nopadding(var_ref)
                     var_H_grad_nopadding = self.get_grad_nopadding(var_H)
                     self.spsr_grad_GenOut.append(grad_LR)
                 else:
@@ -477,10 +476,7 @@ class SRGANModel(BaseModel):
 
                 if self.spsr_enabled and self.cri_grad_gan:  # grad G gan + cls loss
                     pred_g_fake_grad = self.netD_grad(fake_H_grad)
-                    pred_d_real_grad = self.netD_grad(var_ref_grad).detach()
-
-                    l_g_gan_grad = self.l_gan_grad_w * (self.cri_grad_gan(pred_d_real_grad - torch.mean(pred_g_fake_grad), False) +
-                                                        self.cri_grad_gan(pred_g_fake_grad - torch.mean(pred_d_real_grad), True)) / 2
+                    l_g_gan_grad = self.l_gan_grad_w * self.cri_grad_gan(pred_g_fake_grad, True)
                     l_g_total += l_g_gan_grad
 
                 # Scale the loss down by the batch factor.
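
Note on the hunk above: the relativistic-average generator loss for the gradient branch is replaced with a plain GAN loss on the fake gradient map alone. A minimal, self-contained sketch of the two formulations, using a stand-in BCE criterion rather than the repo's cri_grad_gan (all names here are placeholders, not this repo's code):

    # Illustrative sketch only: contrasts the removed relativistic-average
    # generator loss with the plain GAN loss kept by this commit.
    import torch
    import torch.nn as nn

    bce = nn.BCEWithLogitsLoss()  # stand-in for cri_grad_gan

    def g_loss_relativistic(pred_real, pred_fake):
        # Removed form: push fake logits above the mean real logit and vice versa.
        return (bce(pred_real - pred_fake.mean(), torch.zeros_like(pred_real)) +
                bce(pred_fake - pred_real.mean(), torch.ones_like(pred_fake))) / 2

    def g_loss_plain(pred_fake):
        # New form: the generator simply wants the discriminator to call its output real.
        return bce(pred_fake, torch.ones_like(pred_fake))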
@@ -508,7 +504,6 @@ class SRGANModel(BaseModel):
 
             noise = torch.randn_like(var_ref) * noise_theta
             noise.to(self.device)
-            self.optimizer_D.zero_grad()
             real_disc_images = []
             fake_disc_images = []
             for var_L, var_LGAN, var_H, var_ref, pix in zip(self.var_L, self.gan_img, self.var_H, self.var_ref, self.pix):
@@ -533,6 +528,7 @@ class SRGANModel(BaseModel):
                 fake_H = fake_H + noise
                 l_d_fea_real = 0
                 l_d_fea_fake = 0
+                self.optimizer_D.zero_grad()
                 if self.opt['train']['gan_type'] == 'pixgan_fea':
                     # Compute a feature loss which is added to the GAN loss computed later to guide the discriminator better.
                     disc_fea_scale = .1
@@ -548,14 +544,14 @@ class SRGANModel(BaseModel):
                     pred_d_real = self.netD(var_ref)
                     l_d_real = self.cri_gan(pred_d_real, True) / self.mega_batch_factor
                     l_d_real_log = l_d_real * self.mega_batch_factor
-                    with amp.scale_loss(l_d_real, self.optimizer_D, loss_id=2) as l_d_real_scaled:
-                        l_d_real_scaled.backward()
                     # fake
                     pred_d_fake = self.netD(fake_H)
                     l_d_fake = self.cri_gan(pred_d_fake, False) / self.mega_batch_factor
                     l_d_fake_log = l_d_fake * self.mega_batch_factor
-                    with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled:
-                        l_d_fake_scaled.backward()
+                    l_d_total = (l_d_real + l_d_fake) / 2
+                    with amp.scale_loss(l_d_total, self.optimizer_D, loss_id=1) as l_d_total_scaled:
+                        l_d_total_scaled.backward()
                 if 'pixgan' in self.opt['train']['gan_type']:
                     pixdisc_channels, pixdisc_output_reduction = self.netD.module.pixgan_parameters()
                     disc_output_shape = (var_ref.shape[0], pixdisc_channels, var_ref.shape[2] // pixdisc_output_reduction, var_ref.shape[3] // pixdisc_output_reduction)
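
For context, a self-contained sketch of the combined real/fake discriminator update that this hunk and the next one converge on: the real and fake losses are averaged and backpropagated once instead of through two separate scaled backward passes. Model, criterion, tensors and the accumulation factor below are placeholders (not the repo's classes), and Apex AMP scaling is omitted so the snippet runs without apex:

    import torch
    import torch.nn as nn

    netD = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.AdaptiveAvgPool2d(1),
                         nn.Flatten(), nn.Linear(8, 1))
    optimizer_D = torch.optim.Adam(netD.parameters(), lr=1e-4)
    cri_gan = nn.BCEWithLogitsLoss()
    mega_batch_factor = 2                      # hypothetical accumulation factor

    var_ref = torch.rand(4, 3, 32, 32)         # stands in for the real reference batch
    fake_H = torch.rand(4, 3, 32, 32)          # stands in for the generator output

    optimizer_D.zero_grad()
    pred_d_real = netD(var_ref)
    pred_d_fake = netD(fake_H.detach())
    l_d_real = cri_gan(pred_d_real, torch.ones_like(pred_d_real)) / mega_batch_factor
    l_d_fake = cri_gan(pred_d_fake, torch.zeros_like(pred_d_fake)) / mega_batch_factor

    # Single backward on the averaged loss, mirroring the new
    # amp.scale_loss(l_d_total, ...) block in the diff.
    l_d_total = (l_d_real + l_d_fake) / 2
    l_d_total.backward()
    optimizer_D.step()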
@@ -599,15 +595,15 @@ class SRGANModel(BaseModel):
                     l_d_real = self.cri_gan(pred_d_real, real) / self.mega_batch_factor
                     l_d_real_log = l_d_real * self.mega_batch_factor
                     l_d_real += l_d_fea_real
-                    with amp.scale_loss(l_d_real, self.optimizer_D, loss_id=2) as l_d_real_scaled:
-                        l_d_real_scaled.backward()
                     # fake
                     pred_d_fake = self.netD(fake_H)
                     l_d_fake = self.cri_gan(pred_d_fake, fake) / self.mega_batch_factor
                     l_d_fake_log = l_d_fake * self.mega_batch_factor
                     l_d_fake += l_d_fea_fake
-                    with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled:
-                        l_d_fake_scaled.backward()
+                    l_d_total = (l_d_real + l_d_fake) / 2
+                    with amp.scale_loss(l_d_total, self.optimizer_D, loss_id=1) as l_d_total_scaled:
+                        l_d_total_scaled.backward()
 
                     pdr = pred_d_real.detach() + torch.abs(torch.min(pred_d_real))
                     pdr = pdr / torch.max(pdr)
@@ -643,7 +639,6 @@ class SRGANModel(BaseModel):
                     print("Disc forward/backward 2 (RAGAN) %f" % (time() - _t,))
                     _t = time()
 
-                # Append var_ref here, so that we can inspect the alterations the disc made if pixgan
                 var_ref_skips.append(var_ref.detach())
                 self.fake_H.append(fake_H.detach())
             self.optimizer_D.step()
@@ -657,20 +652,19 @@ class SRGANModel(BaseModel):
             for p in self.netD_grad.parameters():
                 p.requires_grad = True
             self.optimizer_D_grad.zero_grad()
-            for var_ref, fake_H in zip(self.var_ref, self.fake_H):
+            for var_ref, fake_H in zip(self.var_ref_skips, self.fake_H):
                 fake_H_grad = self.get_grad_nopadding(fake_H)
                 var_ref_grad = self.get_grad_nopadding(var_ref)
                 pred_d_real_grad = self.netD_grad(var_ref_grad)
                 pred_d_fake_grad = self.netD_grad(fake_H_grad.detach())  # detach to avoid BP to G
                 if self.opt['train']['gan_type'] == 'gan':
-                    l_d_real_grad = self.cri_grad_gan(pred_d_real_grad - torch.mean(pred_d_fake_grad), True)
-                    l_d_fake_grad = self.cri_grad_gan(pred_d_fake_grad - torch.mean(pred_d_real_grad), False)
+                    l_d_real_grad = self.cri_gan(pred_d_real_grad, True) / self.mega_batch_factor
+                    l_d_fake_grad = self.cri_gan(pred_d_fake_grad, False) / self.mega_batch_factor
                 elif self.opt['train']['gan_type'] == 'pixgan':
                     real = torch.ones_like(pred_d_real_grad)
                     fake = torch.zeros_like(pred_d_fake_grad)
-                    l_d_real_grad = self.cri_grad_gan(pred_d_real_grad - torch.mean(pred_d_fake_grad), real)
-                    l_d_fake_grad = self.cri_grad_gan(pred_d_fake_grad - torch.mean(pred_d_real_grad), fake)
+                    l_d_real_grad = self.cri_grad_gan(pred_d_real_grad, real)
+                    l_d_fake_grad = self.cri_grad_gan(pred_d_fake_grad, fake)
                 l_d_total_grad = (l_d_real_grad + l_d_fake_grad) / 2
                 l_d_total_grad /= self.mega_batch_factor
                 with amp.scale_loss(l_d_total_grad, self.optimizer_D_grad, loss_id=2) as l_d_total_grad_scaled:
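
The grad discriminator in the hunk above compares image-gradient maps produced by get_grad_nopadding. As a rough, hedged illustration only, one common way to build such a gradient map is a finite-difference magnitude; this is not necessarily the repo's implementation:

    import torch
    import torch.nn.functional as F

    def image_gradient_magnitude(x: torch.Tensor) -> torch.Tensor:
        # x: (N, C, H, W). Horizontal/vertical differences, padded back to the
        # input size so the map can be compared pixel-wise with x.
        dx = x[:, :, :, 1:] - x[:, :, :, :-1]
        dy = x[:, :, 1:, :] - x[:, :, :-1, :]
        dx = F.pad(dx, (0, 1, 0, 0))
        dy = F.pad(dy, (0, 0, 0, 1))
        return torch.sqrt(dx ** 2 + dy ** 2 + 1e-6)

    grad_map = image_gradient_magnitude(torch.rand(1, 3, 16, 16))
    print(grad_map.shape)  # torch.Size([1, 3, 16, 16])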
@@ -932,4 +926,5 @@ class SRGANModel(BaseModel):
     def save(self, iter_step):
         self.save_network(self.netG, 'G', iter_step)
         self.save_network(self.netD, 'D', iter_step)
-        self.save_network(self.netD_grad, 'D_grad', iter_step)
+        if self.spsr_enabled:
+            self.save_network(self.netD_grad, 'D_grad', iter_step)