Only train discriminator/gan losses when gan_w > 0

James Betker 2020-06-01 15:09:10 -06:00
parent 1eb9c5a059
commit a38dd62489


@@ -239,16 +239,17 @@ class SRGANModel(BaseModel):
                 if step % self.l_fea_w_decay_steps == 0:
                     self.l_fea_w = max(self.l_fea_w_minimum, self.l_fea_w * self.l_fea_w_decay)
-            if self.opt['train']['gan_type'] == 'gan':
-                pred_g_fake = self.netD(fake_GenOut)
-                l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True)
-            elif self.opt['train']['gan_type'] == 'ragan':
-                pred_d_real = self.netD(var_ref).detach()
-                pred_g_fake = self.netD(fake_GenOut)
-                l_g_gan = self.l_gan_w * (
-                    self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False) +
-                    self.cri_gan(pred_g_fake - torch.mean(pred_d_real), True)) / 2
-            l_g_total += l_g_gan
+            if self.l_gan_w > 0:
+                if self.opt['train']['gan_type'] == 'gan':
+                    pred_g_fake = self.netD(fake_GenOut)
+                    l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True)
+                elif self.opt['train']['gan_type'] == 'ragan':
+                    pred_d_real = self.netD(var_ref).detach()
+                    pred_g_fake = self.netD(fake_GenOut)
+                    l_g_gan = self.l_gan_w * (
+                        self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False) +
+                        self.cri_gan(pred_g_fake - torch.mean(pred_d_real), True)) / 2
+                l_g_total += l_g_gan

             # Scale the loss down by the batch factor.
             l_g_total = l_g_total / self.mega_batch_factor
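
A minimal, self-contained sketch of the relativistic-average (RaGAN) generator loss used in the hunk above. It assumes self.cri_gan behaves like BCE-with-logits against a constant real/fake target, which is how a 'gan'-type GAN criterion is commonly implemented; the helper names and the default weight below are illustrative, not taken from this repository.

import torch
import torch.nn.functional as F

def cri_gan(pred, target_is_real):
    # Stand-in GAN criterion: BCE-with-logits against an all-ones or all-zeros target.
    target = torch.full_like(pred, 1.0 if target_is_real else 0.0)
    return F.binary_cross_entropy_with_logits(pred, target)

def ragan_generator_loss(pred_d_real, pred_g_fake, l_gan_w=5e-3):
    # Detach the real logits so only the generator receives gradients,
    # mirroring `self.netD(var_ref).detach()` in the diff above.
    pred_d_real = pred_d_real.detach()
    return l_gan_w * (
        cri_gan(pred_d_real - torch.mean(pred_g_fake), False) +
        cri_gan(pred_g_fake - torch.mean(pred_d_real), True)) / 2

# Toy usage with random "discriminator logits".
fake_logits = torch.randn(8, 1, requires_grad=True)
real_logits = torch.randn(8, 1)
loss = ragan_generator_loss(real_logits, fake_logits)
loss.backward()  # gradients flow only into fake_logits
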
@@ -258,51 +259,52 @@ class SRGANModel(BaseModel):
             self.optimizer_G.step()

         # D
-        for p in self.netD.parameters():
-            p.requires_grad = True
+        if self.l_gan_w > 0:
+            for p in self.netD.parameters():
+                p.requires_grad = True

             noise = torch.randn_like(var_ref[0]) * noise_theta
             noise.to(self.device)
             self.optimizer_D.zero_grad()
             for var_L, var_H, var_ref, pix in zip(self.var_L, self.var_H, var_ref_skips, self.pix):
                 # Re-compute generator outputs (post-update).
                 with torch.no_grad():
                     fake_H = self.netG(var_L)
                     # The following line detaches all generator outputs that are not None.
                     fake_H = tuple([(x.detach() if x is not None else None) for x in list(fake_H)])

                 # Apply noise to the inputs to slow discriminator convergence.
                 var_ref = (var_ref[0] + noise,) + var_ref[1:]
                 fake_H = (fake_H[0] + noise,) + fake_H[1:]
                 if self.opt['train']['gan_type'] == 'gan':
                     # need to forward and backward separately, since batch norm statistics differ
                     # real
                     pred_d_real = self.netD(var_ref)
                     l_d_real = self.cri_gan(pred_d_real, True) / self.mega_batch_factor
                     with amp.scale_loss(l_d_real, self.optimizer_D, loss_id=2) as l_d_real_scaled:
                         l_d_real_scaled.backward()
                     # fake
                     pred_d_fake = self.netD(fake_H)
                     l_d_fake = self.cri_gan(pred_d_fake, False) / self.mega_batch_factor
                     with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled:
                         l_d_fake_scaled.backward()
                 elif self.opt['train']['gan_type'] == 'ragan':
                     # pred_d_real = self.netD(var_ref)
                     # pred_d_fake = self.netD(fake_H.detach())  # detach to avoid BP to G
                     # l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True)
                     # l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real), False)
                     # l_d_total = (l_d_real + l_d_fake) / 2
                     # l_d_total.backward()
                     pred_d_fake = self.netD(fake_H).detach()
                     pred_d_real = self.netD(var_ref)
                     l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True) * 0.5 / self.mega_batch_factor
                     with amp.scale_loss(l_d_real, self.optimizer_D, loss_id=2) as l_d_real_scaled:
                         l_d_real_scaled.backward()
                     pred_d_fake = self.netD(fake_H)
                     l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real.detach()), False) * 0.5 / self.mega_batch_factor
                     with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled:
                         l_d_fake_scaled.backward()
             self.optimizer_D.step()

         # Log sample images from first microbatch.
         if step % 50 == 0:
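
Taken together, the two hunks make the adversarial terms strictly optional: with the GAN loss weight at zero, the generator trains on its remaining losses and the discriminator is never updated. A stripped-down sketch of that control flow follows; the single loss terms and the absence of AMP and micro-batching are simplifying assumptions for illustration, not the repository's actual optimize_parameters().

import torch
import torch.nn as nn
import torch.nn.functional as F

def training_step(netG, netD, optimizer_G, optimizer_D, cri_pix, cri_gan,
                  var_L, var_H, l_gan_w):
    # --- Generator update ---
    fake_H = netG(var_L)
    l_g_total = cri_pix(fake_H, var_H)
    if l_gan_w > 0:  # adversarial term only when the weight is nonzero
        l_g_total = l_g_total + l_gan_w * cri_gan(netD(fake_H), True)
    optimizer_G.zero_grad()
    l_g_total.backward()
    optimizer_G.step()

    # --- Discriminator update: skipped entirely when l_gan_w == 0 ---
    if l_gan_w > 0:
        optimizer_D.zero_grad()
        l_d = (cri_gan(netD(var_H), True) +
               cri_gan(netD(fake_H.detach()), False))
        l_d.backward()
        optimizer_D.step()
    return float(l_g_total)

# Toy usage: 1x1-conv "networks", an L1 pixel loss, and a BCE GAN criterion.
netG = nn.Conv2d(3, 3, 1)
netD = nn.Sequential(nn.Conv2d(3, 1, 1), nn.AdaptiveAvgPool2d(1), nn.Flatten())
optimizer_G = torch.optim.Adam(netG.parameters(), lr=1e-4)
optimizer_D = torch.optim.Adam(netD.parameters(), lr=1e-4)
cri_pix = nn.L1Loss()
cri_gan = lambda pred, real: F.binary_cross_entropy_with_logits(
    pred, torch.full_like(pred, float(real)))
lr_batch, hr_batch = torch.randn(4, 3, 16, 16), torch.randn(4, 3, 16, 16)
training_step(netG, netD, optimizer_G, optimizer_D, cri_pix, cri_gan,
              lr_batch, hr_batch, l_gan_w=0.0)   # pixel-only pre-training
training_step(netG, netD, optimizer_G, optimizer_D, cri_pix, cri_gan,
              lr_batch, hr_batch, l_gan_w=5e-3)  # full GAN training
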