Only train discriminator/gan losses when gan_w > 0
This commit is contained in:
parent
1eb9c5a059
commit
a38dd62489
|
@ -239,6 +239,7 @@ class SRGANModel(BaseModel):
|
|||
if step % self.l_fea_w_decay_steps == 0:
|
||||
self.l_fea_w = max(self.l_fea_w_minimum, self.l_fea_w * self.l_fea_w_decay)
|
||||
|
||||
if self.l_gan_w > 0:
|
||||
if self.opt['train']['gan_type'] == 'gan':
|
||||
pred_g_fake = self.netD(fake_GenOut)
|
||||
l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True)
|
||||
|
@ -258,6 +259,7 @@ class SRGANModel(BaseModel):
|
|||
self.optimizer_G.step()
|
||||
|
||||
# D
|
||||
if self.l_gan_w > 0:
|
||||
for p in self.netD.parameters():
|
||||
p.requires_grad = True
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user