From 6e086d0c20d5547322a648a27fb9695947c51441 Mon Sep 17 00:00:00 2001 From: James Betker Date: Fri, 31 Jul 2020 15:07:10 -0600 Subject: [PATCH] Fix fixed_disc --- codes/models/SRGAN_model.py | 3 ++- codes/models/networks.py | 2 ++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/codes/models/SRGAN_model.py b/codes/models/SRGAN_model.py index 7e6d136f..786181e4 100644 --- a/codes/models/SRGAN_model.py +++ b/codes/models/SRGAN_model.py @@ -110,7 +110,7 @@ class SRGANModel(BaseModel): self.fixed_disc_nets = [] if 'fixed_discriminators' in opt.keys(): for opt_fdisc in opt['fixed_discriminators'].keys(): - self.fixed_disc_nets.append(networks.define_fixed_D(opt['fixed_discriminator'][opt_fdisc]).to(self.device)) + self.fixed_disc_nets.append(networks.define_fixed_D(opt['fixed_discriminators'][opt_fdisc]).to(self.device)) # GD gan loss self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device) @@ -568,6 +568,7 @@ class SRGANModel(BaseModel): if self.cri_fea and l_g_fea_log is not None: self.add_log_entry('feature_weight', fea_w) self.add_log_entry('l_g_fea', l_g_fea_log.item()) + self.add_log_entry('l_g_fix_disc', l_g_fix_disc.item()) if self.l_gan_w > 0: self.add_log_entry('l_g_gan', l_g_gan_log.item()) self.add_log_entry('l_g_total', l_g_total_log.item()) diff --git a/codes/models/networks.py b/codes/models/networks.py index 649436b3..f8df741b 100644 --- a/codes/models/networks.py +++ b/codes/models/networks.py @@ -164,6 +164,8 @@ def define_fixed_D(opt): v.requires_grad = False net.fdisc_weight = opt['weight'] + return net + # Define network used for perceptual loss def define_F(opt, use_bn=False, for_training=False):