From d5fa05959412233f75230d00e25d615400fd44ab Mon Sep 17 00:00:00 2001
From: James Betker <jbetker@gmail.com>
Date: Fri, 31 Jul 2020 14:59:54 -0600
Subject: [PATCH] Add capability to have old discriminators serve as feature
 networks

---
 codes/models/SRGAN_model.py | 16 +++++++++++++++-
 codes/models/networks.py    | 32 ++++++++++++++++++++++++++++----
 2 files changed, 43 insertions(+), 5 deletions(-)

diff --git a/codes/models/SRGAN_model.py b/codes/models/SRGAN_model.py
index 91f20bb1..7e6d136f 100644
--- a/codes/models/SRGAN_model.py
+++ b/codes/models/SRGAN_model.py
@@ -106,6 +106,12 @@ class SRGANModel(BaseModel):
                 else:
                     self.netF = DataParallel(self.netF)
 
+            # You can feed in a list of frozen pre-trained discriminators. These are treated the same as feature losses.
+            self.fixed_disc_nets = []
+            if 'fixed_discriminators' in opt.keys():
+                for opt_fdisc in opt['fixed_discriminators'].keys():
+                    self.fixed_disc_nets.append(networks.define_fixed_D(opt['fixed_discriminator'][opt_fdisc]).to(self.device))
+
             # GD gan loss
             self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device)
             self.l_gan_w = train_opt['gan_weight']
@@ -338,7 +344,15 @@ class SRGANModel(BaseModel):
                     # Note to future self: The BCELoss(0, 1) and BCELoss(0, 0) = .6931
                     # Effectively this means that the generator has only completely "won" when l_d_real and l_d_fake is
                     # equal to this value. If I ever come up with an algorithm that tunes fea/gan weights automatically,
-                    # it should target this value.
+                    # it should target this
+
+                l_g_fix_disc = 0
+                for fixed_disc in self.fixed_disc_nets:
+                    weight = fixed_disc.fdisc_weight
+                    real_fea = fixed_disc(pix).detach()
+                    fake_fea = fixed_disc(fea_GenOut)
+                    l_g_fix_disc += weight * self.cri_fea(fake_fea, real_fea)
+                l_g_total += l_g_fix_disc
 
                 if self.l_gan_w > 0:
                     if self.opt['train']['gan_type'] == 'gan' or 'pixgan' in self.opt['train']['gan_type']:
diff --git a/codes/models/networks.py b/codes/models/networks.py
index 889f9816..649436b3 100644
--- a/codes/models/networks.py
+++ b/codes/models/networks.py
@@ -12,6 +12,7 @@ import models.archs.SwitchedResidualGenerator_arch as SwitchedGen_arch
 import models.archs.SRG1_arch as srg1
 import models.archs.ProgressiveSrg_arch as psrg
 import functools
+from collections import OrderedDict
 
 # Generator
 def define_G(opt, net_key='network_G'):
@@ -113,10 +114,7 @@ def define_G(opt, net_key='network_G'):
     return netG
 
 
-# Discriminator
-def define_D(opt):
-    img_sz = opt['datasets']['train']['target_size']
-    opt_net = opt['network_D']
+def define_D_net(opt_net, img_sz=None):
     which_model = opt_net['which_model_D']
 
     if which_model == 'discriminator_vgg_128':
@@ -140,6 +138,32 @@ def define_D(opt):
         raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model))
     return netD
 
+# Discriminator
+def define_D(opt):
+    img_sz = opt['datasets']['train']['target_size']
+    opt_net = opt['network_D']
+    return define_D_net(opt_net, img_sz)
+
+def define_fixed_D(opt):
+    # Note that this will not work with "old" VGG-style discriminators with dense blocks until the img_size parameter is added.
+    net = define_D_net(opt)
+
+    # Load the model parameters:
+    load_net = torch.load(opt['pretrained_path'])
+    load_net_clean = OrderedDict()  # remove unnecessary 'module.'
+    for k, v in load_net.items():
+        if k.startswith('module.'):
+            load_net_clean[k[7:]] = v
+        else:
+            load_net_clean[k] = v
+    net.load_state_dict(load_net_clean)
+
+    # Put into eval mode, freeze the parameters and set the 'weight' field.
+    net.eval()
+    for k, v in net.named_parameters():
+        v.requires_grad = False
+    net.fdisc_weight = opt['weight']
+
 
 # Define network used for perceptual loss
 def define_F(opt, use_bn=False, for_training=False):