From d5fa05959412233f75230d00e25d615400fd44ab Mon Sep 17 00:00:00 2001
From: James Betker
Date: Fri, 31 Jul 2020 14:59:54 -0600
Subject: [PATCH] Add capability to have old discriminators serve as feature
 networks

---
 codes/models/SRGAN_model.py | 16 +++++++++++++++-
 codes/models/networks.py    | 32 ++++++++++++++++++++++++++++----
 2 files changed, 43 insertions(+), 5 deletions(-)

diff --git a/codes/models/SRGAN_model.py b/codes/models/SRGAN_model.py
index 91f20bb1..7e6d136f 100644
--- a/codes/models/SRGAN_model.py
+++ b/codes/models/SRGAN_model.py
@@ -106,6 +106,12 @@ class SRGANModel(BaseModel):
             else:
                 self.netF = DataParallel(self.netF)
 
+            # You can feed in a list of frozen, pre-trained discriminators. These are treated the same as feature losses.
+            self.fixed_disc_nets = []
+            if 'fixed_discriminators' in opt.keys():
+                for opt_fdisc in opt['fixed_discriminators'].keys():
+                    self.fixed_disc_nets.append(networks.define_fixed_D(opt['fixed_discriminators'][opt_fdisc]).to(self.device))
+
             # GD gan loss
             self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device)
             self.l_gan_w = train_opt['gan_weight']
@@ -338,7 +344,15 @@ class SRGANModel(BaseModel):
             # Note to future self: The BCELoss(0, 1) and BCELoss(0, 0) = .6931
             # Effectively this means that the generator has only completely "won" when l_d_real and l_d_fake is
             # equal to this value. If I ever come up with an algorithm that tunes fea/gan weights automatically,
             # it should target this value.
+
+            l_g_fix_disc = 0
+            for fixed_disc in self.fixed_disc_nets:
+                weight = fixed_disc.fdisc_weight
+                real_fea = fixed_disc(pix).detach()
+                fake_fea = fixed_disc(fea_GenOut)
+                l_g_fix_disc += weight * self.cri_fea(fake_fea, real_fea)
+            l_g_total += l_g_fix_disc
 
             if self.l_gan_w > 0:
                 if self.opt['train']['gan_type'] == 'gan' or 'pixgan' in self.opt['train']['gan_type']:
diff --git a/codes/models/networks.py b/codes/models/networks.py
index 889f9816..649436b3 100644
--- a/codes/models/networks.py
+++ b/codes/models/networks.py
@@ -12,6 +12,7 @@ import models.archs.SwitchedResidualGenerator_arch as SwitchedGen_arch
 import models.archs.SRG1_arch as srg1
 import models.archs.ProgressiveSrg_arch as psrg
 import functools
+from collections import OrderedDict
 
 # Generator
 def define_G(opt, net_key='network_G'):
@@ -113,10 +114,7 @@ def define_G(opt, net_key='network_G'):
     return netG
 
 
-# Discriminator
-def define_D(opt):
-    img_sz = opt['datasets']['train']['target_size']
-    opt_net = opt['network_D']
+def define_D_net(opt_net, img_sz=None):
     which_model = opt_net['which_model_D']
 
     if which_model == 'discriminator_vgg_128':
@@ -140,6 +138,32 @@ def define_D(opt):
         raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model))
     return netD
 
+# Discriminator
+def define_D(opt):
+    img_sz = opt['datasets']['train']['target_size']
+    opt_net = opt['network_D']
+    return define_D_net(opt_net, img_sz)
+
+def define_fixed_D(opt):
+    # Note that this will not work with "old" VGG-style discriminators with dense blocks until the img_sz parameter is added.
+    net = define_D_net(opt)
+
+    # Load the model parameters:
+    load_net = torch.load(opt['pretrained_path'])
+    load_net_clean = OrderedDict()  # remove unnecessary 'module.'
+    for k, v in load_net.items():
+        if k.startswith('module.'):
+            load_net_clean[k[7:]] = v
+        else:
+            load_net_clean[k] = v
+    net.load_state_dict(load_net_clean)
+
+    # Put into eval mode, freeze the parameters and set the 'weight' field.
+    net.eval()
+    for k, v in net.named_parameters():
+        v.requires_grad = False
+    net.fdisc_weight = opt['weight']
+    return net
 
 
 # Define network used for perceptual loss
 def define_F(opt, use_bn=False, for_training=False):
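
Usage sketch (not part of the patch): the new code path expects the root options
structure to contain a 'fixed_discriminators' dict whose named sub-dicts are
passed to define_fixed_D(). A minimal example of such an entry in Python-dict
form; only which_model_D, pretrained_path and weight are read directly by this
patch, while in_nc/nf are assumed architecture options for
discriminator_vgg_128 and the checkpoint path is hypothetical:

    fixed_discriminators = {
        'old_vgg_disc': {
            'which_model_D': 'discriminator_vgg_128',   # routed through define_D_net()
            'in_nc': 3,                                 # assumed: input channel count
            'nf': 64,                                   # assumed: base feature width
            'pretrained_path': 'pretrained/old_D.pth',  # hypothetical checkpoint path
            'weight': 1.0,                              # stored as net.fdisc_weight
        },
    }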
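
For clarity, the training-loop addition is an ordinary perceptual loss that uses
a frozen discriminator in place of VGG. A standalone sketch of the mechanic,
assuming cri_fea is nn.L1Loss (the real criterion is built from the training
options) and frozen_d is any network returned by define_fixed_D():

    import torch.nn as nn

    def fixed_disc_feature_loss(frozen_d, fake, real, weight, cri_fea=nn.L1Loss()):
        # Treat the frozen discriminator exactly like a feature network: its
        # response to the ground-truth image is a constant target.
        real_fea = frozen_d(real).detach()  # no gradient flows into the target
        fake_fea = frozen_d(fake)           # gradient flows back to the generator
        return weight * cri_fea(fake_fea, real_fea)

Detaching real_fea mirrors the existing netF feature loss: combined with eval()
and requires_grad=False in define_fixed_D(), gradients reach only the generator
and the old discriminator stays fixed.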