Only train discriminator/GAN losses when gan_w > 0

This commit is contained in:
James Betker 2020-06-01 15:09:10 -06:00
parent 1eb9c5a059
commit a38dd62489

View File

@ -239,6 +239,7 @@ class SRGANModel(BaseModel):
if step % self.l_fea_w_decay_steps == 0:
self.l_fea_w = max(self.l_fea_w_minimum, self.l_fea_w * self.l_fea_w_decay)
if self.l_gan_w > 0:
if self.opt['train']['gan_type'] == 'gan':
pred_g_fake = self.netD(fake_GenOut)
l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True)
@ -258,6 +259,7 @@ class SRGANModel(BaseModel):
self.optimizer_G.step()
# D
if self.l_gan_w > 0:
for p in self.netD.parameters():
p.requires_grad = True