forked from mrq/DL-Art-School
Only train discriminator/gan losses when gan_w > 0
This commit is contained in:
parent
1eb9c5a059
commit
a38dd62489
|
@ -239,6 +239,7 @@ class SRGANModel(BaseModel):
|
||||||
if step % self.l_fea_w_decay_steps == 0:
|
if step % self.l_fea_w_decay_steps == 0:
|
||||||
self.l_fea_w = max(self.l_fea_w_minimum, self.l_fea_w * self.l_fea_w_decay)
|
self.l_fea_w = max(self.l_fea_w_minimum, self.l_fea_w * self.l_fea_w_decay)
|
||||||
|
|
||||||
|
if self.l_gan_w > 0:
|
||||||
if self.opt['train']['gan_type'] == 'gan':
|
if self.opt['train']['gan_type'] == 'gan':
|
||||||
pred_g_fake = self.netD(fake_GenOut)
|
pred_g_fake = self.netD(fake_GenOut)
|
||||||
l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True)
|
l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True)
|
||||||
|
@ -258,6 +259,7 @@ class SRGANModel(BaseModel):
|
||||||
self.optimizer_G.step()
|
self.optimizer_G.step()
|
||||||
|
|
||||||
# D
|
# D
|
||||||
|
if self.l_gan_w > 0:
|
||||||
for p in self.netD.parameters():
|
for p in self.netD.parameters():
|
||||||
p.requires_grad = True
|
p.requires_grad = True
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user