Only train discriminator/gan losses when gan_w > 0
commit a38dd62489 (parent 1eb9c5a059)
@@ -239,16 +239,17 @@ class SRGANModel(BaseModel):
                     if step % self.l_fea_w_decay_steps == 0:
                         self.l_fea_w = max(self.l_fea_w_minimum, self.l_fea_w * self.l_fea_w_decay)
 
-                if self.opt['train']['gan_type'] == 'gan':
-                    pred_g_fake = self.netD(fake_GenOut)
-                    l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True)
-                elif self.opt['train']['gan_type'] == 'ragan':
-                    pred_d_real = self.netD(var_ref).detach()
-                    pred_g_fake = self.netD(fake_GenOut)
-                    l_g_gan = self.l_gan_w * (
-                        self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False) +
-                        self.cri_gan(pred_g_fake - torch.mean(pred_d_real), True)) / 2
-                l_g_total += l_g_gan
+                if self.l_gan_w > 0:
+                    if self.opt['train']['gan_type'] == 'gan':
+                        pred_g_fake = self.netD(fake_GenOut)
+                        l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True)
+                    elif self.opt['train']['gan_type'] == 'ragan':
+                        pred_d_real = self.netD(var_ref).detach()
+                        pred_g_fake = self.netD(fake_GenOut)
+                        l_g_gan = self.l_gan_w * (
+                            self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False) +
+                            self.cri_gan(pred_g_fake - torch.mean(pred_d_real), True)) / 2
+                    l_g_total += l_g_gan
 
                 # Scale the loss down by the batch factor.
                 l_g_total = l_g_total / self.mega_batch_factor
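The hunk above gates the generator's adversarial term on the GAN weight: when l_gan_w is zero, the discriminator is never invoked during the G step. Below is a minimal standalone sketch of that guard, assuming a BCE-with-logits criterion; toy_netD and bce_gan_loss are illustrative stand-ins for this model's netD and cri_gan, not classes from this repository.

import torch
import torch.nn as nn

toy_netD = nn.Sequential(nn.Flatten(), nn.Linear(16, 1))   # stand-in discriminator
bce = nn.BCEWithLogitsLoss()

def bce_gan_loss(pred, target_is_real):
    # Simplified stand-in for self.cri_gan.
    target = torch.full_like(pred, 1.0 if target_is_real else 0.0)
    return bce(pred, target)

def generator_gan_term(fake, real, gan_type='ragan', l_gan_w=0.0):
    # Weighted adversarial term for the generator; skipped entirely when the weight is zero.
    if l_gan_w <= 0:
        return torch.tensor(0.0)
    pred_g_fake = toy_netD(fake)
    if gan_type == 'gan':
        return l_gan_w * bce_gan_loss(pred_g_fake, True)
    # 'ragan': relativistic average formulation, mirroring the hunk above.
    pred_d_real = toy_netD(real).detach()
    return l_gan_w * (bce_gan_loss(pred_d_real - torch.mean(pred_g_fake), False) +
                      bce_gan_loss(pred_g_fake - torch.mean(pred_d_real), True)) / 2

fake = torch.randn(4, 16, requires_grad=True)
real = torch.randn(4, 16)
print(generator_gan_term(fake, real, 'ragan', 0.0))    # tensor(0.) -- netD never called
print(generator_gan_term(fake, real, 'ragan', 5e-3))   # weighted RaGAN generator loss

Skipping the branch entirely, rather than multiplying by a zero weight, avoids building the discriminator graph at all, which is the point of the guard.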
@@ -258,51 +259,52 @@ class SRGANModel(BaseModel):
         self.optimizer_G.step()
 
         # D
-        for p in self.netD.parameters():
-            p.requires_grad = True
+        if self.l_gan_w > 0:
+            for p in self.netD.parameters():
+                p.requires_grad = True
 
         noise = torch.randn_like(var_ref[0]) * noise_theta
         noise.to(self.device)
         self.optimizer_D.zero_grad()
         for var_L, var_H, var_ref, pix in zip(self.var_L, self.var_H, var_ref_skips, self.pix):
             # Re-compute generator outputs (post-update).
             with torch.no_grad():
                 fake_H = self.netG(var_L)
             # The following line detaches all generator outputs that are not None.
             fake_H = tuple([(x.detach() if x is not None else None) for x in list(fake_H)])
 
             # Apply noise to the inputs to slow discriminator convergence.
             var_ref = (var_ref[0] + noise,) + var_ref[1:]
             fake_H = (fake_H[0] + noise,) + fake_H[1:]
             if self.opt['train']['gan_type'] == 'gan':
                 # need to forward and backward separately, since batch norm statistics differ
                 # real
                 pred_d_real = self.netD(var_ref)
                 l_d_real = self.cri_gan(pred_d_real, True) / self.mega_batch_factor
                 with amp.scale_loss(l_d_real, self.optimizer_D, loss_id=2) as l_d_real_scaled:
                     l_d_real_scaled.backward()
                 # fake
                 pred_d_fake = self.netD(fake_H)
                 l_d_fake = self.cri_gan(pred_d_fake, False) / self.mega_batch_factor
                 with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled:
                     l_d_fake_scaled.backward()
             elif self.opt['train']['gan_type'] == 'ragan':
                 # pred_d_real = self.netD(var_ref)
                 # pred_d_fake = self.netD(fake_H.detach())  # detach to avoid BP to G
                 # l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True)
                 # l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real), False)
                 # l_d_total = (l_d_real + l_d_fake) / 2
                 # l_d_total.backward()
                 pred_d_fake = self.netD(fake_H).detach()
                 pred_d_real = self.netD(var_ref)
                 l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True) * 0.5 / self.mega_batch_factor
                 with amp.scale_loss(l_d_real, self.optimizer_D, loss_id=2) as l_d_real_scaled:
                     l_d_real_scaled.backward()
                 pred_d_fake = self.netD(fake_H)
                 l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real.detach()), False) * 0.5 / self.mega_batch_factor
                 with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled:
                     l_d_fake_scaled.backward()
         self.optimizer_D.step()
 
         # Log sample images from first microbatch.
         if step % 50 == 0:
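Per the commit message, the discriminator should only be trained when the adversarial weight is non-zero; the hunk above adds that guard around the requires_grad toggle at the top of the D step. The standalone sketch below illustrates the intended freeze/unfreeze pattern around a discriminator update. netG, netD, opt_D and the BCE criterion are toy stand-ins, not the DLAS classes, and the whole D update is placed under the guard to match the stated intent.

import torch
import torch.nn as nn

netG = nn.Linear(8, 8)    # toy generator
netD = nn.Linear(8, 1)    # toy discriminator
opt_D = torch.optim.Adam(netD.parameters(), lr=1e-4)
bce = nn.BCEWithLogitsLoss()
l_gan_w = 5e-3            # set to 0.0 to disable adversarial training entirely

# G step: freeze D so generator gradients do not accumulate into it.
for p in netD.parameters():
    p.requires_grad = False
# ... generator losses and optimizer_G.step() would run here ...

# D step: only unfreeze and train the discriminator when the GAN loss is active.
if l_gan_w > 0:
    for p in netD.parameters():
        p.requires_grad = True
    opt_D.zero_grad()
    with torch.no_grad():
        fake = netG(torch.randn(4, 8))   # post-update generator output, detached from G
    real = torch.randn(4, 8)
    l_d = (bce(netD(real), torch.ones(4, 1)) +
           bce(netD(fake), torch.zeros(4, 1))) / 2
    l_d.backward()
    opt_D.step()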
|
|