diff --git a/codes/models/archs/RRDBNet_arch.py b/codes/models/archs/RRDBNet_arch.py index e08a3353..965f0c2c 100644 --- a/codes/models/archs/RRDBNet_arch.py +++ b/codes/models/archs/RRDBNet_arch.py @@ -116,7 +116,7 @@ class RRDBWithBypass(nn.Module): out = self.rdb3(out) bypass = self.bypass(torch.cat([x, out], dim=1)) self.bypass_map = bypass.detach().clone() - # Emperically, we use 0.2 to scale the residual for better performance + # Empirically, we use 0.2 to scale the residual for better performance return out * 0.2 * bypass + x diff --git a/codes/models/archs/discriminator_vgg_arch.py b/codes/models/archs/discriminator_vgg_arch.py index 15539355..2e26504e 100644 --- a/codes/models/archs/discriminator_vgg_arch.py +++ b/codes/models/archs/discriminator_vgg_arch.py @@ -1,8 +1,11 @@ import torch import torch.nn as nn + +from models.archs.RRDBNet_arch import RRDB, RRDBWithBypass from models.archs.arch_util import ConvBnLelu, ConvGnLelu, ExpansionBlock, ConvGnSilu import torch.nn.functional as F from models.archs.SwitchedResidualGenerator_arch import gather_2d +from models.archs.pyramid_arch import Pyramid from utils.util import checkpoint @@ -78,6 +81,7 @@ class Discriminator_VGG_128(nn.Module): out = self.linear2(fea) return out + class Discriminator_VGG_128_GN(nn.Module): # input_img_factor = multiplier to support images over 128x128. Only certain factors are supported. 
class PyramidRRDBDiscriminator(nn.Module):
    """Discriminator built from RRDB-with-bypass stacks around a Pyramid.

    Data flows through, in order: a stride-2 initial convolution, two
    ``RRDBWithBypass`` blocks, a ``Pyramid`` module, and a tail made of two
    more ``RRDBWithBypass`` blocks followed by three ``ConvGnLelu`` layers
    that take the width ``nf -> nf // 2 -> nf // 4 -> 1``. The tail output is
    then averaged over dims 1, 2 and 3.
    """

    def __init__(self, in_nc, nf, block=ConvGnLelu):
        """Create the sub-networks.

        Args:
            in_nc: first positional argument handed to ``block`` for the
                initial convolution (presumably the input channel count).
            nf: feature width shared by the RRDB blocks and the pyramid.
            block: constructor used for the initial convolution only; the
                tail layers are always ``ConvGnLelu``.
        """
        super().__init__()

        # Stem: the only layer built through the configurable ``block``.
        self.initial_conv = block(in_nc, nf, kernel_size=3, stride=2,
                                  activation=True, norm=False, bias=True)

        # Processing applied before the pyramid.
        self.top_proc = nn.Sequential(RRDBWithBypass(nf), RRDBWithBypass(nf))

        self.pyramid = Pyramid(nf,
                               depth=3,
                               processing_convs_per_layer=2,
                               processing_at_point=2,
                               scale_per_level=1.5,
                               norm=True,
                               return_outlevels=False)

        # Processing applied after the pyramid. The last layer has neither
        # normalization nor activation and emits a single-channel map.
        half_nf = nf // 2
        quarter_nf = nf // 4
        tail_layers = [
            RRDBWithBypass(nf),
            RRDBWithBypass(nf),
            ConvGnLelu(nf, half_nf, kernel_size=1, bias=True, norm=True, activation=True),
            ConvGnLelu(half_nf, quarter_nf, kernel_size=1, bias=True, norm=True, activation=True),
            ConvGnLelu(quarter_nf, 1, bias=True, norm=False, activation=False),
        ]
        self.bottom_proc = nn.Sequential(*tail_layers)

    def forward(self, x):
        """Run the discriminator on ``x``.

        Each stage after the initial convolution is invoked through
        ``checkpoint``. The result is the mean over dims 1, 2 and 3 of the
        tail output (assumes a 4-D tensor, giving one value per entry along
        dim 0 -- TODO confirm against callers).
        """
        fea = self.initial_conv(x)
        for stage in (self.top_proc, self.pyramid, self.bottom_proc):
            fea = checkpoint(stage, fea)
        return fea.mean(dim=[1, 2, 3])
netD = SRGAN_arch.RefDiscriminatorVgg128(in_nc=opt_net['in_nc'], nf=opt_net['nf'], input_img_factor=img_sz / 128) elif which_model == "psnr_approximator": netD = SRGAN_arch.PsnrApproximator(nf=opt_net['nf'], input_img_factor=img_sz / 128) + elif which_model == "pyramid_rrdb_disc": + netD = SRGAN_arch.PyramidRRDBDiscriminator(in_nc=3, nf=opt_net['nf']) else: raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model)) return netD