forked from mrq/DL-Art-School

Add capability to have old discriminators serve as feature networks

commit d5fa059594
parent 6b45b35447
@@ -106,6 +106,12 @@ class SRGANModel(BaseModel):
             else:
                 self.netF = DataParallel(self.netF)
 
+            # You can feed in a list of frozen pre-trained discriminators. These are treated the same as feature losses.
+            self.fixed_disc_nets = []
+            if 'fixed_discriminators' in opt.keys():
+                for opt_fdisc in opt['fixed_discriminators'].keys():
+                    self.fixed_disc_nets.append(networks.define_fixed_D(opt['fixed_discriminators'][opt_fdisc]).to(self.device))
+
             # GD gan loss
             self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device)
             self.l_gan_w = train_opt['gan_weight']
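
A note on the hunk above: each entry under 'fixed_discriminators' is passed straight to networks.define_fixed_D, so it must carry whatever that function reads. A minimal sketch of the implied options structure, written as a Python dict standing in for the training options file (the sub-key name and the discriminator type are hypothetical; only 'fixed_discriminators', 'pretrained_path', 'weight', and 'which_model_D' are implied by this commit):

    opt = {
        'fixed_discriminators': {
            'feature_disc_1': {                         # arbitrary name; iterated by the loop above
                'which_model_D': 'some_conv_disc',      # hypothetical; must be a type define_D_net can build without img_sz
                'pretrained_path': 'pretrained_D.pth',  # loaded with torch.load in define_fixed_D
                'weight': 0.2,                          # stored on the net as fdisc_weight
            },
        },
    }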
@@ -338,7 +344,15 @@ class SRGANModel(BaseModel):
             # Note to future self: BCELoss(0, 1) and BCELoss(0, 0) both equal .6931
             # Effectively this means that the generator has only completely "won" when l_d_real and l_d_fake are
             # equal to this value. If I ever come up with an algorithm that tunes fea/gan weights automatically,
             # it should target this value.
 
+            l_g_fix_disc = 0
+            for fixed_disc in self.fixed_disc_nets:
+                weight = fixed_disc.fdisc_weight
+                real_fea = fixed_disc(pix).detach()
+                fake_fea = fixed_disc(fea_GenOut)
+                l_g_fix_disc += weight * self.cri_fea(fake_fea, real_fea)
+            l_g_total += l_g_fix_disc
+
             if self.l_gan_w > 0:
                 if self.opt['train']['gan_type'] == 'gan' or 'pixgan' in self.opt['train']['gan_type']:
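
Two remarks on the hunk above. The .6931 in the comment is ln 2, the binary cross-entropy of a maximally uncertain prediction (p = 0.5) against either label, which is why it marks the point where the generator has fully "won". The added loss itself is plain feature matching against frozen discriminators; a standalone sketch, under the assumption that cri_fea is an elementwise criterion such as L1 and that pix / fea_GenOut are the reference and generated images:

    import torch.nn as nn

    def fixed_disc_feature_loss(fixed_disc_nets, gen_out, target, criterion=nn.L1Loss()):
        # Weighted sum of feature-matching losses from frozen discriminators.
        # detach() keeps the reference branch out of the generator's graph.
        loss = 0
        for disc in fixed_disc_nets:
            real_fea = disc(target).detach()
            fake_fea = disc(gen_out)
            loss = loss + disc.fdisc_weight * criterion(fake_fea, real_fea)
        return loss

The remaining hunks are in the networks module that define_fixed_D lives in.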
@@ -12,6 +12,7 @@ import models.archs.SwitchedResidualGenerator_arch as SwitchedGen_arch
 import models.archs.SRG1_arch as srg1
 import models.archs.ProgressiveSrg_arch as psrg
 import functools
+from collections import OrderedDict
 
 # Generator
 def define_G(opt, net_key='network_G'):
@@ -113,10 +114,7 @@ def define_G(opt, net_key='network_G'):
     return netG
 
 
-# Discriminator
-def define_D(opt):
-    img_sz = opt['datasets']['train']['target_size']
-    opt_net = opt['network_D']
+def define_D_net(opt_net, img_sz=None):
     which_model = opt_net['which_model_D']
 
     if which_model == 'discriminator_vgg_128':
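
The refactor above splits construction out of define_D: define_D_net builds a discriminator from a bare network-options dict plus an explicit image size, so callers without a full training config (such as define_fixed_D below) can reuse it. Illustrative calls, assuming an options dict of the usual shape:

    # Via the full training config, as before:
    netD = define_D(opt)
    # Directly from network options, as define_fixed_D does (img_sz defaults to None):
    netD = define_D_net(opt['network_D'], img_sz=opt['datasets']['train']['target_size'])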
@@ -140,6 +138,32 @@ def define_D(opt):
         raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model))
     return netD
 
+
+# Discriminator
+def define_D(opt):
+    img_sz = opt['datasets']['train']['target_size']
+    opt_net = opt['network_D']
+    return define_D_net(opt_net, img_sz)
+
+
+def define_fixed_D(opt):
+    # Note that this will not work with "old" VGG-style discriminators with dense blocks until the img_size parameter is added.
+    net = define_D_net(opt)
+
+    # Load the model parameters:
+    load_net = torch.load(opt['pretrained_path'])
+    load_net_clean = OrderedDict()  # remove unnecessary 'module.'
+    for k, v in load_net.items():
+        if k.startswith('module.'):
+            load_net_clean[k[7:]] = v
+        else:
+            load_net_clean[k] = v
+    net.load_state_dict(load_net_clean)
+
+    # Put into eval mode, freeze the parameters and set the 'weight' field.
+    net.eval()
+    for k, v in net.named_parameters():
+        v.requires_grad = False
+    net.fdisc_weight = opt['weight']
+    return net
 
 
 # Define network used for perceptual loss
 def define_F(opt, use_bn=False, for_training=False):
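
Putting the pieces together, a hedged end-to-end sketch of loading and querying one frozen discriminator (the option values and the 128x128 input are illustrative, not from the commit; the type must be one define_D_net can build without img_sz):

    import torch

    fdisc_opt = {
        'which_model_D': 'some_conv_disc',   # hypothetical type name
        'pretrained_path': 'pretrained_D.pth',
        'weight': 0.2,
    }
    fixed_d = define_fixed_D(fdisc_opt)

    # Frozen and in eval mode: safe to query without tracking gradients.
    with torch.no_grad():
        score = fixed_d(torch.randn(1, 3, 128, 128))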