From 42a97de75697fd85a0bc8d26b405d1c98efa4ab1 Mon Sep 17 00:00:00 2001
From: James Betker <jbetker@gmail.com>
Date: Wed, 11 Nov 2020 12:14:14 -0700
Subject: [PATCH] Convert PyramidRRDBDisc to RRDBDisc

Had numeric stability issues. This probably makes more sense anyways.
---
 codes/models/archs/discriminator_vgg_arch.py | 30 ++++++++++----------
 codes/models/networks.py                     |  4 +--
 2 files changed, 17 insertions(+), 17 deletions(-)

diff --git a/codes/models/archs/discriminator_vgg_arch.py b/codes/models/archs/discriminator_vgg_arch.py
index 2e26504e..6caeb05c 100644
--- a/codes/models/archs/discriminator_vgg_arch.py
+++ b/codes/models/archs/discriminator_vgg_arch.py
@@ -6,7 +6,7 @@ from models.archs.arch_util import ConvBnLelu, ConvGnLelu, ExpansionBlock, ConvG
 import torch.nn.functional as F
 from models.archs.SwitchedResidualGenerator_arch import gather_2d
 from models.archs.pyramid_arch import Pyramid
-from utils.util import checkpoint
+from utils.util import checkpoint, sequential_checkpoint
 
 
 class Discriminator_VGG_128(nn.Module):
@@ -662,24 +662,24 @@ class SingleImageQualityEstimator(nn.Module):
         return fea
 
 
-class PyramidRRDBDiscriminator(nn.Module):
+class RRDBDiscriminator(nn.Module):
     def __init__(self, in_nc, nf, block=ConvGnLelu):
-        super(PyramidRRDBDiscriminator, self).__init__()
+        super(RRDBDiscriminator, self).__init__()
         self.initial_conv = block(in_nc, nf, kernel_size=3, stride=2, bias=True, norm=False, activation=True)
-        self.top_proc = nn.Sequential(*[RRDBWithBypass(nf),
-                                       RRDBWithBypass(nf)])
-        self.pyramid = Pyramid(nf, depth=3, processing_convs_per_layer=2, processing_at_point=2,
-                               scale_per_level=1.5, norm=True, return_outlevels=False)
-        self.bottom_proc = nn.Sequential(*[RRDBWithBypass(nf),
-                                       RRDBWithBypass(nf),
-                                       ConvGnLelu(nf, nf // 2, kernel_size=1, activation=True, norm=True, bias=True),
-                                       ConvGnLelu(nf // 2, nf // 4, kernel_size=1, activation=True, norm=True, bias=True),
-                                       ConvGnLelu(nf // 4, 1, activation=False, norm=False, bias=True)])
+        self.trunk = nn.ModuleList(*[RRDBWithBypass(nf),
+                                   RRDBWithBypass(nf),
+                                   RRDBWithBypass(nf),
+                                   RRDBWithBypass(nf),
+                                   RRDBWithBypass(nf)])
+
+        self.tail = nn.Sequential(*[
+            ConvGnLelu(nf, nf // 2, kernel_size=1, activation=True, norm=True, bias=True),
+            ConvGnLelu(nf // 2, nf // 4, kernel_size=1, activation=True, norm=True, bias=True),
+            ConvGnLelu(nf // 4, 1, activation=False, norm=False, bias=True)])
 
     def forward(self, x):
         fea = self.initial_conv(x)
-        fea = checkpoint(self.top_proc, fea)
-        fea = checkpoint(self.pyramid, fea)
-        fea = checkpoint(self.bottom_proc, fea)
+        fea = sequential_checkpoint(self.top_proc, 2, fea)
+        fea = checkpoint(self.tail, fea)
         return torch.mean(fea, dim=[1,2,3])
 
diff --git a/codes/models/networks.py b/codes/models/networks.py
index dfc01983..2a67904e 100644
--- a/codes/models/networks.py
+++ b/codes/models/networks.py
@@ -187,8 +187,8 @@ def define_D_net(opt_net, img_sz=None, wrap=False):
         netD = SRGAN_arch.RefDiscriminatorVgg128(in_nc=opt_net['in_nc'], nf=opt_net['nf'], input_img_factor=img_sz / 128)
     elif which_model == "psnr_approximator":
         netD = SRGAN_arch.PsnrApproximator(nf=opt_net['nf'], input_img_factor=img_sz / 128)
-    elif which_model == "pyramid_rrdb_disc":
-        netD = SRGAN_arch.PyramidRRDBDiscriminator(in_nc=3, nf=opt_net['nf'])
+    elif which_model == "rrdb_disc":
+        netD = SRGAN_arch.RRDBDiscriminator(in_nc=3, nf=opt_net['nf'])
     else:
         raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model))
     return netD