Add alternative first block for PixShuffleRRDB
This commit is contained in:
parent
43b7fccc89
commit
5ca53e7786
|
@ -59,7 +59,6 @@ class SwitchedRDB_5C(switched_conv.MultiHeadSwitchedAbstractBlock):
|
|||
return out, atts
|
||||
|
||||
|
||||
|
||||
class RRDB(nn.Module):
|
||||
'''Residual in Residual Dense Block'''
|
||||
|
||||
|
@ -75,7 +74,8 @@ class RRDB(nn.Module):
|
|||
out = self.RDB3(out)
|
||||
return out * 0.2 + x
|
||||
|
||||
class AttentiveRRDB(RRDB):
|
||||
# RRDB block that uses switching on the individual RDB modules that compose it to increase learning diversity.
|
||||
class SwitchedRRDB(RRDB):
|
||||
def __init__(self, nf, gc=32, num_convs=8, init_temperature=1, final_temperature_step=1):
|
||||
super(RRDB, self).__init__()
|
||||
self.RDB1 = SwitchedRDB_5C(nf, gc, num_convs=num_convs, init_temperature=init_temperature)
|
||||
|
@ -118,12 +118,16 @@ class AttentiveRRDB(RRDB):
|
|||
|
||||
# This module performs the majority of the processing done by RRDBNet. It just doesn't have the upsampling at the end.
|
||||
class RRDBTrunk(nn.Module):
|
||||
def __init__(self, nf_in, nf_out, nb, gc=32, initial_stride=1, rrdb_block_f=None):
|
||||
def __init__(self, nf_in, nf_out, nb, gc=32, initial_stride=1, rrdb_block_f=None, conv_first_block=None):
|
||||
super(RRDBTrunk, self).__init__()
|
||||
if rrdb_block_f is None:
|
||||
rrdb_block_f = functools.partial(RRDB, nf=nf_out, gc=gc)
|
||||
|
||||
if conv_first_block is None:
|
||||
self.conv_first = nn.Conv2d(nf_in, nf_out, 7, initial_stride, padding=3, bias=True)
|
||||
else:
|
||||
self.conv_first = conv_first_block
|
||||
|
||||
self.RRDB_trunk, self.rrdb_layers = arch_util.make_layer(rrdb_block_f, nb, True)
|
||||
self.trunk_conv = nn.Conv2d(nf_out, nf_out, 3, 1, 1, bias=True)
|
||||
|
||||
|
@ -168,6 +172,7 @@ class RRDBBase(nn.Module):
|
|||
val.update(block.get_debug_values(step, "trunk_%i_block_%i" % (i, j)))
|
||||
return val
|
||||
|
||||
|
||||
# This class uses a RRDBTrunk to perform processing on an image, then upsamples it.
|
||||
class RRDBNet(RRDBBase):
|
||||
def __init__(self, in_nc, out_nc, nf, nb, gc=32, scale=2, initial_stride=1,
|
||||
|
@ -290,6 +295,22 @@ class AssistedRRDBNet(nn.Module):
|
|||
return (out,)
|
||||
|
||||
|
||||
class PixShuffleInitialConv(nn.Module):
|
||||
def __init__(self, reduction_factor, nf_out):
|
||||
super(PixShuffleInitialConv, self).__init__()
|
||||
self.conv = nn.Conv2d(3 * (reduction_factor ** 2), nf_out, 1)
|
||||
self.r = reduction_factor
|
||||
|
||||
def forward(self, x):
|
||||
(b, f, w, h) = x.shape
|
||||
# This module can only be applied to input images (with 3 channels)
|
||||
assert f == 3
|
||||
# Perform a "reverse-pixel-shuffle", reducing the image size and increasing filter count by self.r**2
|
||||
x = x.contiguous().view(b, 3, w // self.r, self.r, h // self.r, self.r)
|
||||
x = x.permute(0, 1, 3, 5, 2, 4).contiguous().view(b, 3 * (self.r ** 2), w // self.r, h // self.r)
|
||||
# Apply the conv to bring the filter account to the desired size.
|
||||
return self.conv(x)
|
||||
|
||||
# This class uses a RRDBTrunk to perform processing on an image, then upsamples it.
|
||||
class PixShuffleRRDB(RRDBBase):
|
||||
def __init__(self, nf, nb, gc=32, scale=2, rrdb_block_f=None):
|
||||
|
@ -299,7 +320,7 @@ class PixShuffleRRDB(RRDBBase):
|
|||
assert nf % 16 == 0
|
||||
|
||||
# Trunk - does actual processing.
|
||||
self.trunk = RRDBTrunk(3, nf, nb, gc, 4, rrdb_block_f)
|
||||
self.trunk = RRDBTrunk(3, nf, nb, gc, 1, rrdb_block_f, PixShuffleInitialConv(4, nf))
|
||||
self.trunks = [self.trunk]
|
||||
|
||||
# Upsampling
|
||||
|
|
|
@ -35,13 +35,13 @@ def define_G(opt, net_key='network_G'):
|
|||
elif which_model == 'AttentiveRRDBNet':
|
||||
netG = RRDBNet_arch.RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
|
||||
nf=opt_net['nf'], nb=opt_net['nb'], scale=scale,
|
||||
rrdb_block_f=functools.partial(RRDBNet_arch.AttentiveRRDB, nf=opt_net['nf'], gc=opt_net['gc'],
|
||||
rrdb_block_f=functools.partial(RRDBNet_arch.SwitchedRRDB, nf=opt_net['nf'], gc=opt_net['gc'],
|
||||
init_temperature=opt_net['temperature'],
|
||||
final_temperature_step=opt_net['temperature_final_step']))
|
||||
elif which_model == 'MultiRRDBNet':
|
||||
block_f = None
|
||||
if opt_net['attention']:
|
||||
block_f = functools.partial(RRDBNet_arch.AttentiveRRDB, nf=opt_net['nf'], gc=opt_net['gc'],
|
||||
block_f = functools.partial(RRDBNet_arch.SwitchedRRDB, nf=opt_net['nf'], gc=opt_net['gc'],
|
||||
init_temperature=opt_net['temperature'],
|
||||
final_temperature_step=opt_net['temperature_final_step'])
|
||||
netG = RRDBNet_arch.MultiRRDBNet(nf_base=opt_net['nf'], gc_base=opt_net['gc'], lo_blocks=opt_net['lo_blocks'],
|
||||
|
@ -49,7 +49,7 @@ def define_G(opt, net_key='network_G'):
|
|||
elif which_model == 'PixRRDBNet':
|
||||
block_f = None
|
||||
if opt_net['attention']:
|
||||
block_f = functools.partial(RRDBNet_arch.AttentiveRRDB, nf=opt_net['nf'], gc=opt_net['gc'],
|
||||
block_f = functools.partial(RRDBNet_arch.SwitchedRRDB, nf=opt_net['nf'], gc=opt_net['gc'],
|
||||
init_temperature=opt_net['temperature'],
|
||||
final_temperature_step=opt_net['temperature_final_step'])
|
||||
netG = RRDBNet_arch.PixShuffleRRDB(nf=opt_net['nf'], nb=opt_net['nb'], gc=opt_net['gc'], scale=scale, rrdb_block_f=block_f)
|
||||
|
|
Loading…
Reference in New Issue
Block a user