Add capability to have old discriminators serve as feature networks

This commit is contained in:
James Betker 2020-07-31 14:59:54 -06:00
parent 6b45b35447
commit d5fa059594
2 changed files with 43 additions and 5 deletions

View File

@ -106,6 +106,12 @@ class SRGANModel(BaseModel):
else:
self.netF = DataParallel(self.netF)
# You can feed in a list of frozen pre-trained discriminators. These are
# treated the same as feature losses.
self.fixed_disc_nets = []
if 'fixed_discriminators' in opt.keys():
    for opt_fdisc in opt['fixed_discriminators'].keys():
        # BUG FIX: the original indexed opt['fixed_discriminator'] (singular),
        # which raises KeyError — the guard above checks the plural key.
        self.fixed_disc_nets.append(
            networks.define_fixed_D(opt['fixed_discriminators'][opt_fdisc]).to(self.device))
# GD gan loss
self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device)
self.l_gan_w = train_opt['gan_weight']
@ -338,7 +344,15 @@ class SRGANModel(BaseModel):
# Note to future self: The BCELoss(0, 1) and BCELoss(0, 0) = .6931
# Effectively this means that the generator has only completely "won" when l_d_real and l_d_fake is
# equal to this value. If I ever come up with an algorithm that tunes fea/gan weights automatically,
# it should target this value.
# it should target this value.
# Accumulate the feature-matching loss contributed by each frozen,
# pre-trained discriminator (treated like a perceptual/feature loss).
l_g_fix_disc = 0
for fixed_disc in self.fixed_disc_nets:
    # Per-network loss weight attached by define_fixed_D.
    weight = fixed_disc.fdisc_weight
    # Ground-truth features are targets only — detach so no gradient
    # flows through the real-image branch.
    real_fea = fixed_disc(pix).detach()
    fake_fea = fixed_disc(fea_GenOut)
    l_g_fix_disc += weight * self.cri_fea(fake_fea, real_fea)
l_g_total += l_g_fix_disc
if self.l_gan_w > 0:
if self.opt['train']['gan_type'] == 'gan' or 'pixgan' in self.opt['train']['gan_type']:

View File

@ -12,6 +12,7 @@ import models.archs.SwitchedResidualGenerator_arch as SwitchedGen_arch
import models.archs.SRG1_arch as srg1
import models.archs.ProgressiveSrg_arch as psrg
import functools
from collections import OrderedDict
# Generator
def define_G(opt, net_key='network_G'):
@ -113,10 +114,7 @@ def define_G(opt, net_key='network_G'):
return netG
# Discriminator
def define_D(opt):
img_sz = opt['datasets']['train']['target_size']
opt_net = opt['network_D']
def define_D_net(opt_net, img_sz=None):
which_model = opt_net['which_model_D']
if which_model == 'discriminator_vgg_128':
@ -140,6 +138,32 @@ def define_D(opt):
raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model))
return netD
# Discriminator
def define_D(opt):
    """Construct the trainable discriminator specified under opt['network_D'].

    The image size is taken from the training dataset's target_size so
    VGG-style discriminators with dense heads can size themselves.
    """
    disc_opt = opt['network_D']
    target_sz = opt['datasets']['train']['target_size']
    return define_D_net(disc_opt, target_sz)
def define_fixed_D(opt):
    """Build a frozen, pre-trained discriminator to be used as a feature-loss network.

    opt must provide the 'which_model_D' network options plus:
      - 'pretrained_path': checkpoint file loaded via torch.load
      - 'weight': scalar loss weight, stored on the net as `fdisc_weight`

    Returns the discriminator in eval mode with all parameters frozen.
    """
    # NOTE: this will not work with "old" VGG-style discriminators with dense
    # blocks until the img_size parameter is added (img_sz defaults to None).
    net = define_D_net(opt)

    # Load the pretrained weights, stripping any DataParallel 'module.' prefix.
    load_net = torch.load(opt['pretrained_path'])
    load_net_clean = OrderedDict()
    for k, v in load_net.items():
        load_net_clean[k[7:] if k.startswith('module.') else k] = v
    net.load_state_dict(load_net_clean)

    # Freeze: eval mode and no gradients; stash the loss weight on the net.
    net.eval()
    for p in net.parameters():
        p.requires_grad = False
    net.fdisc_weight = opt['weight']
    # BUG FIX: the original fell off the end and implicitly returned None,
    # making callers crash on `.to(device)`.
    return net
# Define network used for perceptual loss
def define_F(opt, use_bn=False, for_training=False):