forked from mrq/DL-Art-School
Add capability to have old discriminators serve as feature networks
This commit is contained in:
parent
6b45b35447
commit
d5fa059594
|
@ -106,6 +106,12 @@ class SRGANModel(BaseModel):
|
|||
else:
|
||||
self.netF = DataParallel(self.netF)
|
||||
|
||||
# You can feed in a list of frozen pre-trained discriminators. These are treated the same as feature losses.
|
||||
self.fixed_disc_nets = []
|
||||
if 'fixed_discriminators' in opt.keys():
|
||||
for opt_fdisc in opt['fixed_discriminators'].keys():
|
||||
self.fixed_disc_nets.append(networks.define_fixed_D(opt['fixed_discriminator'][opt_fdisc]).to(self.device))
|
||||
|
||||
# GD gan loss
|
||||
self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device)
|
||||
self.l_gan_w = train_opt['gan_weight']
|
||||
|
@ -338,7 +344,15 @@ class SRGANModel(BaseModel):
|
|||
# Note to future self: The BCELoss(0, 1) and BCELoss(0, 0) = .6931
|
||||
# Effectively this means that the generator has only completely "won" when l_d_real and l_d_fake is
|
||||
# equal to this value. If I ever come up with an algorithm that tunes fea/gan weights automatically,
|
||||
# it should target this value.
|
||||
# it should target this
|
||||
|
||||
l_g_fix_disc = 0
|
||||
for fixed_disc in self.fixed_disc_nets:
|
||||
weight = fixed_disc.fdisc_weight
|
||||
real_fea = fixed_disc(pix).detach()
|
||||
fake_fea = fixed_disc(fea_GenOut)
|
||||
l_g_fix_disc += weight * self.cri_fea(fake_fea, real_fea)
|
||||
l_g_total += l_g_fix_disc
|
||||
|
||||
if self.l_gan_w > 0:
|
||||
if self.opt['train']['gan_type'] == 'gan' or 'pixgan' in self.opt['train']['gan_type']:
|
||||
|
|
|
@ -12,6 +12,7 @@ import models.archs.SwitchedResidualGenerator_arch as SwitchedGen_arch
|
|||
import models.archs.SRG1_arch as srg1
|
||||
import models.archs.ProgressiveSrg_arch as psrg
|
||||
import functools
|
||||
from collections import OrderedDict
|
||||
|
||||
# Generator
|
||||
def define_G(opt, net_key='network_G'):
|
||||
|
@ -113,10 +114,7 @@ def define_G(opt, net_key='network_G'):
|
|||
return netG
|
||||
|
||||
|
||||
# Discriminator
|
||||
def define_D(opt):
|
||||
img_sz = opt['datasets']['train']['target_size']
|
||||
opt_net = opt['network_D']
|
||||
def define_D_net(opt_net, img_sz=None):
|
||||
which_model = opt_net['which_model_D']
|
||||
|
||||
if which_model == 'discriminator_vgg_128':
|
||||
|
@ -140,6 +138,32 @@ def define_D(opt):
|
|||
raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model))
|
||||
return netD
|
||||
|
||||
# Discriminator
|
||||
def define_D(opt):
|
||||
img_sz = opt['datasets']['train']['target_size']
|
||||
opt_net = opt['network_D']
|
||||
return define_D_net(opt_net, img_sz)
|
||||
|
||||
def define_fixed_D(opt):
|
||||
# Note that this will not work with "old" VGG-style discriminators with dense blocks until the img_size parameter is added.
|
||||
net = define_D_net(opt)
|
||||
|
||||
# Load the model parameters:
|
||||
load_net = torch.load(opt['pretrained_path'])
|
||||
load_net_clean = OrderedDict() # remove unnecessary 'module.'
|
||||
for k, v in load_net.items():
|
||||
if k.startswith('module.'):
|
||||
load_net_clean[k[7:]] = v
|
||||
else:
|
||||
load_net_clean[k] = v
|
||||
net.load_state_dict(load_net_clean)
|
||||
|
||||
# Put into eval mode, freeze the parameters and set the 'weight' field.
|
||||
net.eval()
|
||||
for k, v in net.named_parameters():
|
||||
v.requires_grad = False
|
||||
net.fdisc_weight = opt['weight']
|
||||
|
||||
|
||||
# Define network used for perceptual loss
|
||||
def define_F(opt, use_bn=False, for_training=False):
|
||||
|
|
Loading…
Reference in New Issue
Block a user