diff --git a/codes/models/archs/discriminator_vgg_arch.py b/codes/models/archs/discriminator_vgg_arch.py index 2e26504e..6caeb05c 100644 --- a/codes/models/archs/discriminator_vgg_arch.py +++ b/codes/models/archs/discriminator_vgg_arch.py @@ -6,7 +6,7 @@ from models.archs.arch_util import ConvBnLelu, ConvGnLelu, ExpansionBlock, ConvG import torch.nn.functional as F from models.archs.SwitchedResidualGenerator_arch import gather_2d from models.archs.pyramid_arch import Pyramid -from utils.util import checkpoint +from utils.util import checkpoint, sequential_checkpoint class Discriminator_VGG_128(nn.Module): @@ -662,24 +662,24 @@ class SingleImageQualityEstimator(nn.Module): return fea -class PyramidRRDBDiscriminator(nn.Module): +class RRDBDiscriminator(nn.Module): def __init__(self, in_nc, nf, block=ConvGnLelu): - super(PyramidRRDBDiscriminator, self).__init__() + super(RRDBDiscriminator, self).__init__() self.initial_conv = block(in_nc, nf, kernel_size=3, stride=2, bias=True, norm=False, activation=True) - self.top_proc = nn.Sequential(*[RRDBWithBypass(nf), - RRDBWithBypass(nf)]) - self.pyramid = Pyramid(nf, depth=3, processing_convs_per_layer=2, processing_at_point=2, - scale_per_level=1.5, norm=True, return_outlevels=False) - self.bottom_proc = nn.Sequential(*[RRDBWithBypass(nf), - RRDBWithBypass(nf), - ConvGnLelu(nf, nf // 2, kernel_size=1, activation=True, norm=True, bias=True), - ConvGnLelu(nf // 2, nf // 4, kernel_size=1, activation=True, norm=True, bias=True), - ConvGnLelu(nf // 4, 1, activation=False, norm=False, bias=True)]) + self.trunk = nn.ModuleList(*[RRDBWithBypass(nf), + RRDBWithBypass(nf), + RRDBWithBypass(nf), + RRDBWithBypass(nf), + RRDBWithBypass(nf)]) + + self.tail = nn.Sequential(*[ + ConvGnLelu(nf, nf // 2, kernel_size=1, activation=True, norm=True, bias=True), + ConvGnLelu(nf // 2, nf // 4, kernel_size=1, activation=True, norm=True, bias=True), + ConvGnLelu(nf // 4, 1, activation=False, norm=False, bias=True)]) def forward(self, x): fea = self.initial_conv(x) - fea = checkpoint(self.top_proc, fea) - fea = checkpoint(self.pyramid, fea) - fea = checkpoint(self.bottom_proc, fea) + fea = sequential_checkpoint(self.top_proc, 2, fea) + fea = checkpoint(self.tail, fea) return torch.mean(fea, dim=[1,2,3]) diff --git a/codes/models/networks.py b/codes/models/networks.py index dfc01983..2a67904e 100644 --- a/codes/models/networks.py +++ b/codes/models/networks.py @@ -187,8 +187,8 @@ def define_D_net(opt_net, img_sz=None, wrap=False): netD = SRGAN_arch.RefDiscriminatorVgg128(in_nc=opt_net['in_nc'], nf=opt_net['nf'], input_img_factor=img_sz / 128) elif which_model == "psnr_approximator": netD = SRGAN_arch.PsnrApproximator(nf=opt_net['nf'], input_img_factor=img_sz / 128) - elif which_model == "pyramid_rrdb_disc": - netD = SRGAN_arch.PyramidRRDBDiscriminator(in_nc=3, nf=opt_net['nf']) + elif which_model == "rrdb_disc": + netD = SRGAN_arch.RRDBDiscriminator(in_nc=3, nf=opt_net['nf']) else: raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model)) return netD