Add RRDB Initial Stride
Allows downsampling immediately before processing, which reduces network complexity on higher resolution images but keeps a higher filter count.
This commit is contained in:
parent
76a38b6a53
commit
dc17545083
|
@ -47,12 +47,12 @@ class RRDB(nn.Module):
|
||||||
|
|
||||||
|
|
||||||
class RRDBNet(nn.Module):
|
class RRDBNet(nn.Module):
|
||||||
def __init__(self, in_nc, out_nc, nf, nb, gc=32, scale=2):
|
def __init__(self, in_nc, out_nc, nf, nb, gc=32, scale=2, initial_stride=1):
|
||||||
super(RRDBNet, self).__init__()
|
super(RRDBNet, self).__init__()
|
||||||
RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)
|
RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)
|
||||||
|
|
||||||
self.scale = scale
|
self.scale = scale
|
||||||
self.conv_first = nn.Conv2d(in_nc, nf, 7, 1, padding=3, bias=True)
|
self.conv_first = nn.Conv2d(in_nc, nf, 7, initial_stride, padding=3, bias=True)
|
||||||
self.RRDB_trunk = arch_util.make_layer(RRDB_block_f, nb)
|
self.RRDB_trunk = arch_util.make_layer(RRDB_block_f, nb)
|
||||||
self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||||
#### upsampling
|
#### upsampling
|
||||||
|
@ -87,10 +87,11 @@ class AssistedRRDBNet(nn.Module):
|
||||||
# nb=number of additional blocks after the assistance layers.
|
# nb=number of additional blocks after the assistance layers.
|
||||||
# gc=growth channel inside of residual blocks
|
# gc=growth channel inside of residual blocks
|
||||||
# scale=the number of times the output is doubled in size.
|
# scale=the number of times the output is doubled in size.
|
||||||
def __init__(self, in_nc, out_nc, nf, nb, gc=32, scale=2):
|
# initial_stride=the stride on the first conv. can be used to downsample the image for processing.
|
||||||
|
def __init__(self, in_nc, out_nc, nf, nb, gc=32, scale=2, initial_stride=1):
|
||||||
super(AssistedRRDBNet, self).__init__()
|
super(AssistedRRDBNet, self).__init__()
|
||||||
self.scale = scale
|
self.scale = scale
|
||||||
self.conv_first = nn.Conv2d(in_nc, nf, 7, 1, padding=3, bias=True)
|
self.conv_first = nn.Conv2d(in_nc, nf, 7, initial_stride, padding=3, bias=True)
|
||||||
|
|
||||||
# Set-up the assist-net, which should do feature extraction for us.
|
# Set-up the assist-net, which should do feature extraction for us.
|
||||||
self.assistnet = torchvision.models.wide_resnet50_2(pretrained=True)
|
self.assistnet = torchvision.models.wide_resnet50_2(pretrained=True)
|
||||||
|
|
|
@ -22,8 +22,12 @@ def define_G(opt, net_key='network_G'):
|
||||||
nf=opt_net['nf'], nb=opt_net['nb'], upscale=opt_net['scale'])
|
nf=opt_net['nf'], nb=opt_net['nb'], upscale=opt_net['scale'])
|
||||||
elif which_model == 'RRDBNet':
|
elif which_model == 'RRDBNet':
|
||||||
# RRDB does scaling in two steps, so take the sqrt of the scale we actually want to achieve and feed it to RRDB.
|
# RRDB does scaling in two steps, so take the sqrt of the scale we actually want to achieve and feed it to RRDB.
|
||||||
|
initial_stride = 1 if 'initial_stride' not in opt_net else opt_net['initial_stride']
|
||||||
|
assert initial_stride == 1 or initial_stride == 2
|
||||||
|
# Need to adjust the scale the generator sees by the stride since the stride causes a down-sample.
|
||||||
|
gen_scale = scale * initial_stride
|
||||||
netG = RRDBNet_arch.RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
|
netG = RRDBNet_arch.RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
|
||||||
nf=opt_net['nf'], nb=opt_net['nb'], scale=scale)
|
nf=opt_net['nf'], nb=opt_net['nb'], scale=scale, initial_stride=initial_stride)
|
||||||
elif which_model == 'AssistedRRDBNet':
|
elif which_model == 'AssistedRRDBNet':
|
||||||
netG = RRDBNet_arch.AssistedRRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
|
netG = RRDBNet_arch.AssistedRRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
|
||||||
nf=opt_net['nf'], nb=opt_net['nb'], scale=scale)
|
nf=opt_net['nf'], nb=opt_net['nb'], scale=scale)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user