Add alternative first block for PixShuffleRRDB

This commit is contained in:
James Betker 2020-06-10 21:45:24 -06:00
parent 43b7fccc89
commit 5ca53e7786
2 changed files with 33 additions and 12 deletions

View File

@ -59,7 +59,6 @@ class SwitchedRDB_5C(switched_conv.MultiHeadSwitchedAbstractBlock):
return out, atts
class RRDB(nn.Module):
'''Residual in Residual Dense Block'''
@ -75,7 +74,8 @@ class RRDB(nn.Module):
out = self.RDB3(out)
return out * 0.2 + x
class AttentiveRRDB(RRDB):
# RRDB block that uses switching on the individual RDB modules that compose it to increase learning diversity.
class SwitchedRRDB(RRDB):
def __init__(self, nf, gc=32, num_convs=8, init_temperature=1, final_temperature_step=1):
super(RRDB, self).__init__()
self.RDB1 = SwitchedRDB_5C(nf, gc, num_convs=num_convs, init_temperature=init_temperature)
@ -118,12 +118,16 @@ class AttentiveRRDB(RRDB):
# This module performs the majority of the processing done by RRDBNet. It just doesn't have the upsampling at the end.
class RRDBTrunk(nn.Module):
def __init__(self, nf_in, nf_out, nb, gc=32, initial_stride=1, rrdb_block_f=None):
def __init__(self, nf_in, nf_out, nb, gc=32, initial_stride=1, rrdb_block_f=None, conv_first_block=None):
super(RRDBTrunk, self).__init__()
if rrdb_block_f is None:
rrdb_block_f = functools.partial(RRDB, nf=nf_out, gc=gc)
if conv_first_block is None:
self.conv_first = nn.Conv2d(nf_in, nf_out, 7, initial_stride, padding=3, bias=True)
else:
self.conv_first = conv_first_block
self.RRDB_trunk, self.rrdb_layers = arch_util.make_layer(rrdb_block_f, nb, True)
self.trunk_conv = nn.Conv2d(nf_out, nf_out, 3, 1, 1, bias=True)
@ -168,6 +172,7 @@ class RRDBBase(nn.Module):
val.update(block.get_debug_values(step, "trunk_%i_block_%i" % (i, j)))
return val
# This class uses a RRDBTrunk to perform processing on an image, then upsamples it.
class RRDBNet(RRDBBase):
def __init__(self, in_nc, out_nc, nf, nb, gc=32, scale=2, initial_stride=1,
@ -290,6 +295,22 @@ class AssistedRRDBNet(nn.Module):
return (out,)
class PixShuffleInitialConv(nn.Module):
def __init__(self, reduction_factor, nf_out):
super(PixShuffleInitialConv, self).__init__()
self.conv = nn.Conv2d(3 * (reduction_factor ** 2), nf_out, 1)
self.r = reduction_factor
def forward(self, x):
(b, f, w, h) = x.shape
# This module can only be applied to input images (with 3 channels)
assert f == 3
# Perform a "reverse-pixel-shuffle", reducing the image size and increasing filter count by self.r**2
x = x.contiguous().view(b, 3, w // self.r, self.r, h // self.r, self.r)
x = x.permute(0, 1, 3, 5, 2, 4).contiguous().view(b, 3 * (self.r ** 2), w // self.r, h // self.r)
# Apply the conv to bring the filter account to the desired size.
return self.conv(x)
# This class uses a RRDBTrunk to perform processing on an image, then upsamples it.
class PixShuffleRRDB(RRDBBase):
def __init__(self, nf, nb, gc=32, scale=2, rrdb_block_f=None):
@ -299,7 +320,7 @@ class PixShuffleRRDB(RRDBBase):
assert nf % 16 == 0
# Trunk - does actual processing.
self.trunk = RRDBTrunk(3, nf, nb, gc, 4, rrdb_block_f)
self.trunk = RRDBTrunk(3, nf, nb, gc, 1, rrdb_block_f, PixShuffleInitialConv(4, nf))
self.trunks = [self.trunk]
# Upsampling

View File

@ -35,13 +35,13 @@ def define_G(opt, net_key='network_G'):
elif which_model == 'AttentiveRRDBNet':
netG = RRDBNet_arch.RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
nf=opt_net['nf'], nb=opt_net['nb'], scale=scale,
rrdb_block_f=functools.partial(RRDBNet_arch.AttentiveRRDB, nf=opt_net['nf'], gc=opt_net['gc'],
rrdb_block_f=functools.partial(RRDBNet_arch.SwitchedRRDB, nf=opt_net['nf'], gc=opt_net['gc'],
init_temperature=opt_net['temperature'],
final_temperature_step=opt_net['temperature_final_step']))
elif which_model == 'MultiRRDBNet':
block_f = None
if opt_net['attention']:
block_f = functools.partial(RRDBNet_arch.AttentiveRRDB, nf=opt_net['nf'], gc=opt_net['gc'],
block_f = functools.partial(RRDBNet_arch.SwitchedRRDB, nf=opt_net['nf'], gc=opt_net['gc'],
init_temperature=opt_net['temperature'],
final_temperature_step=opt_net['temperature_final_step'])
netG = RRDBNet_arch.MultiRRDBNet(nf_base=opt_net['nf'], gc_base=opt_net['gc'], lo_blocks=opt_net['lo_blocks'],
@ -49,7 +49,7 @@ def define_G(opt, net_key='network_G'):
elif which_model == 'PixRRDBNet':
block_f = None
if opt_net['attention']:
block_f = functools.partial(RRDBNet_arch.AttentiveRRDB, nf=opt_net['nf'], gc=opt_net['gc'],
block_f = functools.partial(RRDBNet_arch.SwitchedRRDB, nf=opt_net['nf'], gc=opt_net['gc'],
init_temperature=opt_net['temperature'],
final_temperature_step=opt_net['temperature_final_step'])
netG = RRDBNet_arch.PixShuffleRRDB(nf=opt_net['nf'], nb=opt_net['nb'], gc=opt_net['gc'], scale=scale, rrdb_block_f=block_f)