From 7d38381d4649cfea376f93277af546eff0633034 Mon Sep 17 00:00:00 2001 From: James Betker Date: Thu, 29 Oct 2020 09:48:10 -0600 Subject: [PATCH] Add scaling to rrdb --- codes/models/archs/RRDBNet_arch.py | 11 ++++++++--- codes/models/networks.py | 2 +- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/codes/models/archs/RRDBNet_arch.py b/codes/models/archs/RRDBNet_arch.py index 6c4ac41c..acfd3ce9 100644 --- a/codes/models/archs/RRDBNet_arch.py +++ b/codes/models/archs/RRDBNet_arch.py @@ -143,10 +143,12 @@ class RRDBNet(nn.Module): num_blocks=23, growth_channels=32, body_block=RRDB, - blocks_per_checkpoint=4): + blocks_per_checkpoint=4, + scale=4): super(RRDBNet, self).__init__() self.num_blocks = num_blocks self.blocks_per_checkpoint = blocks_per_checkpoint + self.scale = scale self.conv_first = nn.Conv2d(in_channels, mid_channels, 3, 1, 1) self.body = make_layer( body_block, @@ -184,8 +186,11 @@ class RRDBNet(nn.Module): # upsample feat = self.lrelu( self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest'))) - feat = self.lrelu( - self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest'))) + if self.scale == 4: + feat = self.lrelu( + self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest'))) + else: + feat = self.lrelu(self.conv_up2(feat)) out = self.conv_last(self.lrelu(self.conv_hr(feat))) return out diff --git a/codes/models/networks.py b/codes/models/networks.py index cc2b6f31..132d6ed1 100644 --- a/codes/models/networks.py +++ b/codes/models/networks.py @@ -42,7 +42,7 @@ def define_G(opt, net_key='network_G', scale=None): elif which_model == 'RRDBNetBypass': netG = RRDBNet_arch.RRDBNet(in_channels=opt_net['in_nc'], out_channels=opt_net['out_nc'], mid_channels=opt_net['nf'], num_blocks=opt_net['nb'], body_block=RRDBNet_arch.RRDBWithBypass, - blocks_per_checkpoint=opt_net['blocks_per_checkpoint']) + blocks_per_checkpoint=opt_net['blocks_per_checkpoint'], scale=opt_net['scale']) elif which_model == 'rcan': #args: n_resgroups, n_resblocks, res_scale, reduction, scale, n_feats opt_net['rgb_range'] = 255