Add scaling to rrdb
This commit is contained in:
parent
607ff3c67c
commit
7d38381d46
|
@ -143,10 +143,12 @@ class RRDBNet(nn.Module):
|
||||||
num_blocks=23,
|
num_blocks=23,
|
||||||
growth_channels=32,
|
growth_channels=32,
|
||||||
body_block=RRDB,
|
body_block=RRDB,
|
||||||
blocks_per_checkpoint=4):
|
blocks_per_checkpoint=4,
|
||||||
|
scale=4):
|
||||||
super(RRDBNet, self).__init__()
|
super(RRDBNet, self).__init__()
|
||||||
self.num_blocks = num_blocks
|
self.num_blocks = num_blocks
|
||||||
self.blocks_per_checkpoint = blocks_per_checkpoint
|
self.blocks_per_checkpoint = blocks_per_checkpoint
|
||||||
|
self.scale = scale
|
||||||
self.conv_first = nn.Conv2d(in_channels, mid_channels, 3, 1, 1)
|
self.conv_first = nn.Conv2d(in_channels, mid_channels, 3, 1, 1)
|
||||||
self.body = make_layer(
|
self.body = make_layer(
|
||||||
body_block,
|
body_block,
|
||||||
|
@ -184,8 +186,11 @@ class RRDBNet(nn.Module):
|
||||||
# upsample
|
# upsample
|
||||||
feat = self.lrelu(
|
feat = self.lrelu(
|
||||||
self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest')))
|
self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest')))
|
||||||
feat = self.lrelu(
|
if self.scale == 4:
|
||||||
self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest')))
|
feat = self.lrelu(
|
||||||
|
self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest')))
|
||||||
|
else:
|
||||||
|
feat = self.lrelu(self.conv_up2(feat))
|
||||||
out = self.conv_last(self.lrelu(self.conv_hr(feat)))
|
out = self.conv_last(self.lrelu(self.conv_hr(feat)))
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
|
@ -42,7 +42,7 @@ def define_G(opt, net_key='network_G', scale=None):
|
||||||
elif which_model == 'RRDBNetBypass':
|
elif which_model == 'RRDBNetBypass':
|
||||||
netG = RRDBNet_arch.RRDBNet(in_channels=opt_net['in_nc'], out_channels=opt_net['out_nc'],
|
netG = RRDBNet_arch.RRDBNet(in_channels=opt_net['in_nc'], out_channels=opt_net['out_nc'],
|
||||||
mid_channels=opt_net['nf'], num_blocks=opt_net['nb'], body_block=RRDBNet_arch.RRDBWithBypass,
|
mid_channels=opt_net['nf'], num_blocks=opt_net['nb'], body_block=RRDBNet_arch.RRDBWithBypass,
|
||||||
blocks_per_checkpoint=opt_net['blocks_per_checkpoint'])
|
blocks_per_checkpoint=opt_net['blocks_per_checkpoint'], scale=opt_net['scale'])
|
||||||
elif which_model == 'rcan':
|
elif which_model == 'rcan':
|
||||||
#args: n_resgroups, n_resblocks, res_scale, reduction, scale, n_feats
|
#args: n_resgroups, n_resblocks, res_scale, reduction, scale, n_feats
|
||||||
opt_net['rgb_range'] = 255
|
opt_net['rgb_range'] = 255
|
||||||
|
|
Loading…
Reference in New Issue
Block a user