From dc17545083f52f16afe32ebdeb5e96d852caff6d Mon Sep 17 00:00:00 2001 From: James Betker Date: Tue, 2 Jun 2020 10:47:15 -0600 Subject: [PATCH] Add RRDB Initial Stride Allows downsampling immediately before processing, which reduces network complexity on higher resolution images but keeps a higher filter count. --- codes/models/archs/RRDBNet_arch.py | 9 +++++---- codes/models/networks.py | 6 +++++- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/codes/models/archs/RRDBNet_arch.py b/codes/models/archs/RRDBNet_arch.py index 5a73c490..64873c4a 100644 --- a/codes/models/archs/RRDBNet_arch.py +++ b/codes/models/archs/RRDBNet_arch.py @@ -47,12 +47,12 @@ class RRDB(nn.Module): class RRDBNet(nn.Module): - def __init__(self, in_nc, out_nc, nf, nb, gc=32, scale=2): + def __init__(self, in_nc, out_nc, nf, nb, gc=32, scale=2, initial_stride=1): super(RRDBNet, self).__init__() RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc) self.scale = scale - self.conv_first = nn.Conv2d(in_nc, nf, 7, 1, padding=3, bias=True) + self.conv_first = nn.Conv2d(in_nc, nf, 7, initial_stride, padding=3, bias=True) self.RRDB_trunk = arch_util.make_layer(RRDB_block_f, nb) self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) #### upsampling @@ -87,10 +87,11 @@ class AssistedRRDBNet(nn.Module): # nb=number of additional blocks after the assistance layers. # gc=growth channel inside of residual blocks # scale=the number of times the output is doubled in size. - def __init__(self, in_nc, out_nc, nf, nb, gc=32, scale=2): + # initial_stride=the stride on the first conv. can be used to downsample the image for processing. + def __init__(self, in_nc, out_nc, nf, nb, gc=32, scale=2, initial_stride=1): super(AssistedRRDBNet, self).__init__() self.scale = scale - self.conv_first = nn.Conv2d(in_nc, nf, 7, 1, padding=3, bias=True) + self.conv_first = nn.Conv2d(in_nc, nf, 7, initial_stride, padding=3, bias=True) # Set-up the assist-net, which should do feature extraction for us. self.assistnet = torchvision.models.wide_resnet50_2(pretrained=True) diff --git a/codes/models/networks.py b/codes/models/networks.py index e6ec0f4a..137a3361 100644 --- a/codes/models/networks.py +++ b/codes/models/networks.py @@ -22,8 +22,12 @@ def define_G(opt, net_key='network_G'): nf=opt_net['nf'], nb=opt_net['nb'], upscale=opt_net['scale']) elif which_model == 'RRDBNet': # RRDB does scaling in two steps, so take the sqrt of the scale we actually want to achieve and feed it to RRDB. + initial_stride = 1 if 'initial_stride' not in opt_net else opt_net['initial_stride'] + assert initial_stride == 1 or initial_stride == 2 + # Need to adjust the scale the generator sees by the stride since the stride causes a down-sample. + gen_scale = scale * initial_stride netG = RRDBNet_arch.RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], - nf=opt_net['nf'], nb=opt_net['nb'], scale=scale) + nf=opt_net['nf'], nb=opt_net['nb'], scale=scale, initial_stride=initial_stride) elif which_model == 'AssistedRRDBNet': netG = RRDBNet_arch.AssistedRRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'], nb=opt_net['nb'], scale=scale)