Add scaling to rrdb

This commit is contained in:
James Betker 2020-10-29 09:48:10 -06:00
parent 607ff3c67c
commit 7d38381d46
2 changed files with 9 additions and 4 deletions

View File

@ -143,10 +143,12 @@ class RRDBNet(nn.Module):
num_blocks=23, num_blocks=23,
growth_channels=32, growth_channels=32,
body_block=RRDB, body_block=RRDB,
blocks_per_checkpoint=4): blocks_per_checkpoint=4,
scale=4):
super(RRDBNet, self).__init__() super(RRDBNet, self).__init__()
self.num_blocks = num_blocks self.num_blocks = num_blocks
self.blocks_per_checkpoint = blocks_per_checkpoint self.blocks_per_checkpoint = blocks_per_checkpoint
self.scale = scale
self.conv_first = nn.Conv2d(in_channels, mid_channels, 3, 1, 1) self.conv_first = nn.Conv2d(in_channels, mid_channels, 3, 1, 1)
self.body = make_layer( self.body = make_layer(
body_block, body_block,
@ -184,8 +186,11 @@ class RRDBNet(nn.Module):
# upsample # upsample
feat = self.lrelu( feat = self.lrelu(
self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest'))) self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest')))
feat = self.lrelu( if self.scale == 4:
self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest'))) feat = self.lrelu(
self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest')))
else:
feat = self.lrelu(self.conv_up2(feat))
out = self.conv_last(self.lrelu(self.conv_hr(feat))) out = self.conv_last(self.lrelu(self.conv_hr(feat)))
return out return out

View File

@ -42,7 +42,7 @@ def define_G(opt, net_key='network_G', scale=None):
elif which_model == 'RRDBNetBypass': elif which_model == 'RRDBNetBypass':
netG = RRDBNet_arch.RRDBNet(in_channels=opt_net['in_nc'], out_channels=opt_net['out_nc'], netG = RRDBNet_arch.RRDBNet(in_channels=opt_net['in_nc'], out_channels=opt_net['out_nc'],
mid_channels=opt_net['nf'], num_blocks=opt_net['nb'], body_block=RRDBNet_arch.RRDBWithBypass, mid_channels=opt_net['nf'], num_blocks=opt_net['nb'], body_block=RRDBNet_arch.RRDBWithBypass,
blocks_per_checkpoint=opt_net['blocks_per_checkpoint']) blocks_per_checkpoint=opt_net['blocks_per_checkpoint'], scale=opt_net['scale'])
elif which_model == 'rcan': elif which_model == 'rcan':
#args: n_resgroups, n_resblocks, res_scale, reduction, scale, n_feats #args: n_resgroups, n_resblocks, res_scale, reduction, scale, n_feats
opt_net['rgb_range'] = 255 opt_net['rgb_range'] = 255