Add RRDB Initial Stride

Allows downsampling immediately before processing, which reduces network complexity on
higher resolution images but keeps a higher filter count.
This commit is contained in:
James Betker 2020-06-02 10:47:15 -06:00
parent 76a38b6a53
commit dc17545083
2 changed files with 10 additions and 5 deletions

View File

@ -47,12 +47,12 @@ class RRDB(nn.Module):
class RRDBNet(nn.Module):
def __init__(self, in_nc, out_nc, nf, nb, gc=32, scale=2):
def __init__(self, in_nc, out_nc, nf, nb, gc=32, scale=2, initial_stride=1):
super(RRDBNet, self).__init__()
RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)
self.scale = scale
self.conv_first = nn.Conv2d(in_nc, nf, 7, 1, padding=3, bias=True)
self.conv_first = nn.Conv2d(in_nc, nf, 7, initial_stride, padding=3, bias=True)
self.RRDB_trunk = arch_util.make_layer(RRDB_block_f, nb)
self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
#### upsampling
@ -87,10 +87,11 @@ class AssistedRRDBNet(nn.Module):
# nb=number of additional blocks after the assistance layers.
# gc=growth channel inside of residual blocks
# scale=the number of times the output is doubled in size.
def __init__(self, in_nc, out_nc, nf, nb, gc=32, scale=2):
# initial_stride=the stride on the first conv. can be used to downsample the image for processing.
def __init__(self, in_nc, out_nc, nf, nb, gc=32, scale=2, initial_stride=1):
super(AssistedRRDBNet, self).__init__()
self.scale = scale
self.conv_first = nn.Conv2d(in_nc, nf, 7, 1, padding=3, bias=True)
self.conv_first = nn.Conv2d(in_nc, nf, 7, initial_stride, padding=3, bias=True)
# Set-up the assist-net, which should do feature extraction for us.
self.assistnet = torchvision.models.wide_resnet50_2(pretrained=True)

View File

@ -22,8 +22,12 @@ def define_G(opt, net_key='network_G'):
nf=opt_net['nf'], nb=opt_net['nb'], upscale=opt_net['scale'])
elif which_model == 'RRDBNet':
# RRDB does scaling in two steps, so take the sqrt of the scale we actually want to achieve and feed it to RRDB.
initial_stride = 1 if 'initial_stride' not in opt_net else opt_net['initial_stride']
assert initial_stride == 1 or initial_stride == 2
# Need to adjust the scale the generator sees by the stride since the stride causes a down-sample.
gen_scale = scale * initial_stride
netG = RRDBNet_arch.RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
nf=opt_net['nf'], nb=opt_net['nb'], scale=scale)
nf=opt_net['nf'], nb=opt_net['nb'], scale=scale, initial_stride=initial_stride)
elif which_model == 'AssistedRRDBNet':
netG = RRDBNet_arch.AssistedRRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
nf=opt_net['nf'], nb=opt_net['nb'], scale=scale)