Add scaling to rrdb
This commit is contained in:
parent
607ff3c67c
commit
7d38381d46
|
@ -143,10 +143,12 @@ class RRDBNet(nn.Module):
|
|||
num_blocks=23,
|
||||
growth_channels=32,
|
||||
body_block=RRDB,
|
||||
blocks_per_checkpoint=4):
|
||||
blocks_per_checkpoint=4,
|
||||
scale=4):
|
||||
super(RRDBNet, self).__init__()
|
||||
self.num_blocks = num_blocks
|
||||
self.blocks_per_checkpoint = blocks_per_checkpoint
|
||||
self.scale = scale
|
||||
self.conv_first = nn.Conv2d(in_channels, mid_channels, 3, 1, 1)
|
||||
self.body = make_layer(
|
||||
body_block,
|
||||
|
@ -184,8 +186,11 @@ class RRDBNet(nn.Module):
|
|||
# upsample
|
||||
feat = self.lrelu(
|
||||
self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest')))
|
||||
feat = self.lrelu(
|
||||
self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest')))
|
||||
if self.scale == 4:
|
||||
feat = self.lrelu(
|
||||
self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest')))
|
||||
else:
|
||||
feat = self.lrelu(self.conv_up2(feat))
|
||||
out = self.conv_last(self.lrelu(self.conv_hr(feat)))
|
||||
return out
|
||||
|
||||
|
|
|
@ -42,7 +42,7 @@ def define_G(opt, net_key='network_G', scale=None):
|
|||
elif which_model == 'RRDBNetBypass':
|
||||
netG = RRDBNet_arch.RRDBNet(in_channels=opt_net['in_nc'], out_channels=opt_net['out_nc'],
|
||||
mid_channels=opt_net['nf'], num_blocks=opt_net['nb'], body_block=RRDBNet_arch.RRDBWithBypass,
|
||||
blocks_per_checkpoint=opt_net['blocks_per_checkpoint'])
|
||||
blocks_per_checkpoint=opt_net['blocks_per_checkpoint'], scale=opt_net['scale'])
|
||||
elif which_model == 'rcan':
|
||||
#args: n_resgroups, n_resblocks, res_scale, reduction, scale, n_feats
|
||||
opt_net['rgb_range'] = 255
|
||||
|
|
Loading…
Reference in New Issue
Block a user