Only train discriminator/gan losses when gan_w > 0
This commit is contained in:
parent 1eb9c5a059
commit a38dd62489
@@ -239,16 +239,17 @@ class SRGANModel(BaseModel):
                     if step % self.l_fea_w_decay_steps == 0:
                         self.l_fea_w = max(self.l_fea_w_minimum, self.l_fea_w * self.l_fea_w_decay)
 
-                if self.opt['train']['gan_type'] == 'gan':
-                    pred_g_fake = self.netD(fake_GenOut)
-                    l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True)
-                elif self.opt['train']['gan_type'] == 'ragan':
-                    pred_d_real = self.netD(var_ref).detach()
-                    pred_g_fake = self.netD(fake_GenOut)
-                    l_g_gan = self.l_gan_w * (
-                        self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False) +
-                        self.cri_gan(pred_g_fake - torch.mean(pred_d_real), True)) / 2
-                l_g_total += l_g_gan
+                if self.l_gan_w > 0:
+                    if self.opt['train']['gan_type'] == 'gan':
+                        pred_g_fake = self.netD(fake_GenOut)
+                        l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True)
+                    elif self.opt['train']['gan_type'] == 'ragan':
+                        pred_d_real = self.netD(var_ref).detach()
+                        pred_g_fake = self.netD(fake_GenOut)
+                        l_g_gan = self.l_gan_w * (
+                            self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False) +
+                            self.cri_gan(pred_g_fake - torch.mean(pred_d_real), True)) / 2
+                    l_g_total += l_g_gan
 
                 # Scale the loss down by the batch factor.
                 l_g_total = l_g_total / self.mega_batch_factor
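A minimal standalone sketch of the pattern this hunk introduces (the generator's adversarial term is only computed and added when the GAN weight is positive). The names `generator_gan_loss`, `net_d`, `cri_gan`, `fake_out`, and `real_ref` are hypothetical stand-ins for the `self.netD` / `self.cri_gan` / `fake_GenOut` / `var_ref` members used in the diff:

import torch

def generator_gan_loss(net_d, cri_gan, fake_out, real_ref, gan_type, gan_weight):
    # Guarded adversarial term: skipped entirely when the weight is zero,
    # which is the behavior this commit adds.
    if gan_weight <= 0:
        return torch.zeros(())

    pred_fake = net_d(fake_out)
    if gan_type == 'gan':
        return gan_weight * cri_gan(pred_fake, True)
    elif gan_type == 'ragan':
        # Relativistic average GAN: each prediction is compared against the
        # mean prediction of the opposite class; the real branch is detached
        # so the generator never backprops through the reference images.
        pred_real = net_d(real_ref).detach()
        return gan_weight * (cri_gan(pred_real - torch.mean(pred_fake), False) +
                             cri_gan(pred_fake - torch.mean(pred_real), True)) / 2
    raise ValueError('unknown gan_type: ' + gan_type)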
@@ -258,51 +259,52 @@ class SRGANModel(BaseModel):
             self.optimizer_G.step()
 
         # D
-        for p in self.netD.parameters():
-            p.requires_grad = True
+        if self.l_gan_w > 0:
+            for p in self.netD.parameters():
+                p.requires_grad = True
 
-        noise = torch.randn_like(var_ref[0]) * noise_theta
-        noise.to(self.device)
-        self.optimizer_D.zero_grad()
-        for var_L, var_H, var_ref, pix in zip(self.var_L, self.var_H, var_ref_skips, self.pix):
-            # Re-compute generator outputs (post-update).
-            with torch.no_grad():
-                fake_H = self.netG(var_L)
-                # The following line detaches all generator outputs that are not None.
-                fake_H = tuple([(x.detach() if x is not None else None) for x in list(fake_H)])
+            noise = torch.randn_like(var_ref[0]) * noise_theta
+            noise.to(self.device)
+            self.optimizer_D.zero_grad()
+            for var_L, var_H, var_ref, pix in zip(self.var_L, self.var_H, var_ref_skips, self.pix):
+                # Re-compute generator outputs (post-update).
+                with torch.no_grad():
+                    fake_H = self.netG(var_L)
+                    # The following line detaches all generator outputs that are not None.
+                    fake_H = tuple([(x.detach() if x is not None else None) for x in list(fake_H)])
 
-            # Apply noise to the inputs to slow discriminator convergence.
-            var_ref = (var_ref[0] + noise,) + var_ref[1:]
-            fake_H = (fake_H[0] + noise,) + fake_H[1:]
-            if self.opt['train']['gan_type'] == 'gan':
-                # need to forward and backward separately, since batch norm statistics differ
-                # real
-                pred_d_real = self.netD(var_ref)
-                l_d_real = self.cri_gan(pred_d_real, True) / self.mega_batch_factor
-                with amp.scale_loss(l_d_real, self.optimizer_D, loss_id=2) as l_d_real_scaled:
-                    l_d_real_scaled.backward()
-                # fake
-                pred_d_fake = self.netD(fake_H)
-                l_d_fake = self.cri_gan(pred_d_fake, False) / self.mega_batch_factor
-                with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled:
-                    l_d_fake_scaled.backward()
-            elif self.opt['train']['gan_type'] == 'ragan':
-                # pred_d_real = self.netD(var_ref)
-                # pred_d_fake = self.netD(fake_H.detach())  # detach to avoid BP to G
-                # l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True)
-                # l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real), False)
-                # l_d_total = (l_d_real + l_d_fake) / 2
-                # l_d_total.backward()
-                pred_d_fake = self.netD(fake_H).detach()
-                pred_d_real = self.netD(var_ref)
-                l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True) * 0.5 / self.mega_batch_factor
-                with amp.scale_loss(l_d_real, self.optimizer_D, loss_id=2) as l_d_real_scaled:
-                    l_d_real_scaled.backward()
-                pred_d_fake = self.netD(fake_H)
-                l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real.detach()), False) * 0.5 / self.mega_batch_factor
-                with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled:
-                    l_d_fake_scaled.backward()
-        self.optimizer_D.step()
+                # Apply noise to the inputs to slow discriminator convergence.
+                var_ref = (var_ref[0] + noise,) + var_ref[1:]
+                fake_H = (fake_H[0] + noise,) + fake_H[1:]
+                if self.opt['train']['gan_type'] == 'gan':
+                    # need to forward and backward separately, since batch norm statistics differ
+                    # real
+                    pred_d_real = self.netD(var_ref)
+                    l_d_real = self.cri_gan(pred_d_real, True) / self.mega_batch_factor
+                    with amp.scale_loss(l_d_real, self.optimizer_D, loss_id=2) as l_d_real_scaled:
+                        l_d_real_scaled.backward()
+                    # fake
+                    pred_d_fake = self.netD(fake_H)
+                    l_d_fake = self.cri_gan(pred_d_fake, False) / self.mega_batch_factor
+                    with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled:
+                        l_d_fake_scaled.backward()
+                elif self.opt['train']['gan_type'] == 'ragan':
+                    # pred_d_real = self.netD(var_ref)
+                    # pred_d_fake = self.netD(fake_H.detach())  # detach to avoid BP to G
+                    # l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True)
+                    # l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real), False)
+                    # l_d_total = (l_d_real + l_d_fake) / 2
+                    # l_d_total.backward()
+                    pred_d_fake = self.netD(fake_H).detach()
+                    pred_d_real = self.netD(var_ref)
+                    l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True) * 0.5 / self.mega_batch_factor
+                    with amp.scale_loss(l_d_real, self.optimizer_D, loss_id=2) as l_d_real_scaled:
+                        l_d_real_scaled.backward()
+                    pred_d_fake = self.netD(fake_H)
+                    l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real.detach()), False) * 0.5 / self.mega_batch_factor
+                    with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled:
+                        l_d_fake_scaled.backward()
+            self.optimizer_D.step()
 
         # Log sample images from first microbatch.
         if step % 50 == 0:
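Likewise, a condensed sketch of the guarded discriminator update from the second hunk. The names `discriminator_step`, `net_d`, `net_g`, `opt_d`, and `cri_gan` are hypothetical stand-ins for the members used above, and the apex `amp.scale_loss` wrapping and `mega_batch_factor` scaling from the real code are omitted:

import torch

def discriminator_step(net_d, net_g, opt_d, cri_gan, lq, ref, gan_weight, noise_theta=0.0):
    # With gan_weight <= 0 the discriminator is neither queried nor updated,
    # mirroring the guard this commit adds.
    if gan_weight <= 0:
        return None

    for p in net_d.parameters():
        p.requires_grad = True

    opt_d.zero_grad()
    with torch.no_grad():
        fake = net_g(lq).detach()

    # Inject the same noise into both inputs to slow discriminator convergence.
    noise = torch.randn_like(ref) * noise_theta
    ref_in, fake_in = ref + noise, fake + noise

    # Relativistic average ('ragan') loss; the real and fake halves are
    # backpropagated separately, detaching the opposite-class mean each time.
    pred_fake = net_d(fake_in).detach()
    pred_real = net_d(ref_in)
    l_d_real = cri_gan(pred_real - torch.mean(pred_fake), True) * 0.5
    l_d_real.backward()

    pred_fake = net_d(fake_in)
    l_d_fake = cri_gan(pred_fake - torch.mean(pred_real.detach()), False) * 0.5
    l_d_fake.backward()

    opt_d.step()
    return (l_d_real + l_d_fake).item()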