From af1968f9e54d4343e6554a596b0ef6f80e9c14ba Mon Sep 17 00:00:00 2001
From: James Betker <jbetker@gmail.com>
Date: Tue, 19 May 2020 09:41:16 -0600
Subject: [PATCH] Allow passthrough discriminator to have passthrough disabled
 from config

---
 .../models/archs/DiscriminatorResnet_arch_passthrough.py | 9 ++++++++-
 codes/models/networks.py                                 | 3 ++-
 2 files changed, 10 insertions(+), 2 deletions(-)

diff --git a/codes/models/archs/DiscriminatorResnet_arch_passthrough.py b/codes/models/archs/DiscriminatorResnet_arch_passthrough.py
index 462261f7..236b1647 100644
--- a/codes/models/archs/DiscriminatorResnet_arch_passthrough.py
+++ b/codes/models/archs/DiscriminatorResnet_arch_passthrough.py
@@ -107,11 +107,13 @@ class FixupBottleneck(nn.Module):
 
 class FixupResNet(nn.Module):
 
-    def __init__(self, block, layers, num_filters=64, num_classes=1000, input_img_size=64, number_skips=2, use_bn=False):
+    def __init__(self, block, layers, num_filters=64, num_classes=1000, input_img_size=64, number_skips=2, use_bn=False,
+                 disable_passthrough=False):
         super(FixupResNet, self).__init__()
         self.num_layers = sum(layers)
         self.inplanes = 3
         self.number_skips = number_skips
+        self.disable_passthrough = disable_passthrough
         self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
         self.layer0 = self._make_layer(block, num_filters*2, layers[0], stride=2, use_bn=use_bn, conv_type=conv5x5)
         if number_skips > 0:
@@ -163,6 +165,11 @@ class FixupResNet(nn.Module):
             # Or just a tuple with only the high res input (this assumes number_skips was set right).
             x = x[0]
 
+        if self.disable_passthrough:
+            if self.number_skips > 0:
+                med_skip = torch.zeros_like(med_skip)
+            if self.number_skips > 1:
+                lo_skip = torch.zeros_like(lo_skip)
         x = self.layer0(x)
         if self.number_skips > 0:
             x = torch.cat([x, med_skip], dim=1)
diff --git a/codes/models/networks.py b/codes/models/networks.py
index 446322bc..98c17b46 100644
--- a/codes/models/networks.py
+++ b/codes/models/networks.py
@@ -75,7 +75,8 @@ def define_D(opt):
         netD = DiscriminatorResnet_arch.fixup_resnet34(num_filters=opt_net['nf'], num_classes=1, input_img_size=img_sz)
     elif which_model == 'discriminator_resnet_passthrough':
         netD = DiscriminatorResnet_arch_passthrough.fixup_resnet34(num_filters=opt_net['nf'], num_classes=1, input_img_size=img_sz,
-                                                                   number_skips=opt_net['number_skips'], use_bn=True)
+                                                                   number_skips=opt_net['number_skips'], use_bn=True,
+                                                                   disable_passthrough=opt_net['disable_passthrough'])
     else:
         raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model))
     return netD