Fix fixed_disc
This commit is contained in:
parent
d5fa059594
commit
6e086d0c20
|
@ -110,7 +110,7 @@ class SRGANModel(BaseModel):
|
|||
self.fixed_disc_nets = []
|
||||
if 'fixed_discriminators' in opt.keys():
|
||||
for opt_fdisc in opt['fixed_discriminators'].keys():
|
||||
self.fixed_disc_nets.append(networks.define_fixed_D(opt['fixed_discriminator'][opt_fdisc]).to(self.device))
|
||||
self.fixed_disc_nets.append(networks.define_fixed_D(opt['fixed_discriminators'][opt_fdisc]).to(self.device))
|
||||
|
||||
# GD gan loss
|
||||
self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device)
|
||||
|
@ -568,6 +568,7 @@ class SRGANModel(BaseModel):
|
|||
if self.cri_fea and l_g_fea_log is not None:
|
||||
self.add_log_entry('feature_weight', fea_w)
|
||||
self.add_log_entry('l_g_fea', l_g_fea_log.item())
|
||||
self.add_log_entry('l_g_fix_disc', l_g_fix_disc.item())
|
||||
if self.l_gan_w > 0:
|
||||
self.add_log_entry('l_g_gan', l_g_gan_log.item())
|
||||
self.add_log_entry('l_g_total', l_g_total_log.item())
|
||||
|
|
|
@ -164,6 +164,8 @@ def define_fixed_D(opt):
|
|||
v.requires_grad = False
|
||||
net.fdisc_weight = opt['weight']
|
||||
|
||||
return net
|
||||
|
||||
|
||||
# Define network used for perceptual loss
|
||||
def define_F(opt, use_bn=False, for_training=False):
|
||||
|
|
Loading…
Reference in New Issue
Block a user