Add alternative first block for PixShuffleRRDB
This commit is contained in:
parent
43b7fccc89
commit
5ca53e7786
|
@ -59,7 +59,6 @@ class SwitchedRDB_5C(switched_conv.MultiHeadSwitchedAbstractBlock):
|
||||||
return out, atts
|
return out, atts
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class RRDB(nn.Module):
|
class RRDB(nn.Module):
|
||||||
'''Residual in Residual Dense Block'''
|
'''Residual in Residual Dense Block'''
|
||||||
|
|
||||||
|
@ -75,7 +74,8 @@ class RRDB(nn.Module):
|
||||||
out = self.RDB3(out)
|
out = self.RDB3(out)
|
||||||
return out * 0.2 + x
|
return out * 0.2 + x
|
||||||
|
|
||||||
class AttentiveRRDB(RRDB):
|
# RRDB block that uses switching on the individual RDB modules that compose it to increase learning diversity.
|
||||||
|
class SwitchedRRDB(RRDB):
|
||||||
def __init__(self, nf, gc=32, num_convs=8, init_temperature=1, final_temperature_step=1):
|
def __init__(self, nf, gc=32, num_convs=8, init_temperature=1, final_temperature_step=1):
|
||||||
super(RRDB, self).__init__()
|
super(RRDB, self).__init__()
|
||||||
self.RDB1 = SwitchedRDB_5C(nf, gc, num_convs=num_convs, init_temperature=init_temperature)
|
self.RDB1 = SwitchedRDB_5C(nf, gc, num_convs=num_convs, init_temperature=init_temperature)
|
||||||
|
@ -118,12 +118,16 @@ class AttentiveRRDB(RRDB):
|
||||||
|
|
||||||
# This module performs the majority of the processing done by RRDBNet. It just doesn't have the upsampling at the end.
|
# This module performs the majority of the processing done by RRDBNet. It just doesn't have the upsampling at the end.
|
||||||
class RRDBTrunk(nn.Module):
|
class RRDBTrunk(nn.Module):
|
||||||
def __init__(self, nf_in, nf_out, nb, gc=32, initial_stride=1, rrdb_block_f=None):
|
def __init__(self, nf_in, nf_out, nb, gc=32, initial_stride=1, rrdb_block_f=None, conv_first_block=None):
|
||||||
super(RRDBTrunk, self).__init__()
|
super(RRDBTrunk, self).__init__()
|
||||||
if rrdb_block_f is None:
|
if rrdb_block_f is None:
|
||||||
rrdb_block_f = functools.partial(RRDB, nf=nf_out, gc=gc)
|
rrdb_block_f = functools.partial(RRDB, nf=nf_out, gc=gc)
|
||||||
|
|
||||||
self.conv_first = nn.Conv2d(nf_in, nf_out, 7, initial_stride, padding=3, bias=True)
|
if conv_first_block is None:
|
||||||
|
self.conv_first = nn.Conv2d(nf_in, nf_out, 7, initial_stride, padding=3, bias=True)
|
||||||
|
else:
|
||||||
|
self.conv_first = conv_first_block
|
||||||
|
|
||||||
self.RRDB_trunk, self.rrdb_layers = arch_util.make_layer(rrdb_block_f, nb, True)
|
self.RRDB_trunk, self.rrdb_layers = arch_util.make_layer(rrdb_block_f, nb, True)
|
||||||
self.trunk_conv = nn.Conv2d(nf_out, nf_out, 3, 1, 1, bias=True)
|
self.trunk_conv = nn.Conv2d(nf_out, nf_out, 3, 1, 1, bias=True)
|
||||||
|
|
||||||
|
@ -168,6 +172,7 @@ class RRDBBase(nn.Module):
|
||||||
val.update(block.get_debug_values(step, "trunk_%i_block_%i" % (i, j)))
|
val.update(block.get_debug_values(step, "trunk_%i_block_%i" % (i, j)))
|
||||||
return val
|
return val
|
||||||
|
|
||||||
|
|
||||||
# This class uses a RRDBTrunk to perform processing on an image, then upsamples it.
|
# This class uses a RRDBTrunk to perform processing on an image, then upsamples it.
|
||||||
class RRDBNet(RRDBBase):
|
class RRDBNet(RRDBBase):
|
||||||
def __init__(self, in_nc, out_nc, nf, nb, gc=32, scale=2, initial_stride=1,
|
def __init__(self, in_nc, out_nc, nf, nb, gc=32, scale=2, initial_stride=1,
|
||||||
|
@ -290,6 +295,22 @@ class AssistedRRDBNet(nn.Module):
|
||||||
return (out,)
|
return (out,)
|
||||||
|
|
||||||
|
|
||||||
|
class PixShuffleInitialConv(nn.Module):
|
||||||
|
def __init__(self, reduction_factor, nf_out):
|
||||||
|
super(PixShuffleInitialConv, self).__init__()
|
||||||
|
self.conv = nn.Conv2d(3 * (reduction_factor ** 2), nf_out, 1)
|
||||||
|
self.r = reduction_factor
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
(b, f, w, h) = x.shape
|
||||||
|
# This module can only be applied to input images (with 3 channels)
|
||||||
|
assert f == 3
|
||||||
|
# Perform a "reverse-pixel-shuffle", reducing the image size and increasing filter count by self.r**2
|
||||||
|
x = x.contiguous().view(b, 3, w // self.r, self.r, h // self.r, self.r)
|
||||||
|
x = x.permute(0, 1, 3, 5, 2, 4).contiguous().view(b, 3 * (self.r ** 2), w // self.r, h // self.r)
|
||||||
|
# Apply the conv to bring the filter account to the desired size.
|
||||||
|
return self.conv(x)
|
||||||
|
|
||||||
# This class uses a RRDBTrunk to perform processing on an image, then upsamples it.
|
# This class uses a RRDBTrunk to perform processing on an image, then upsamples it.
|
||||||
class PixShuffleRRDB(RRDBBase):
|
class PixShuffleRRDB(RRDBBase):
|
||||||
def __init__(self, nf, nb, gc=32, scale=2, rrdb_block_f=None):
|
def __init__(self, nf, nb, gc=32, scale=2, rrdb_block_f=None):
|
||||||
|
@ -299,7 +320,7 @@ class PixShuffleRRDB(RRDBBase):
|
||||||
assert nf % 16 == 0
|
assert nf % 16 == 0
|
||||||
|
|
||||||
# Trunk - does actual processing.
|
# Trunk - does actual processing.
|
||||||
self.trunk = RRDBTrunk(3, nf, nb, gc, 4, rrdb_block_f)
|
self.trunk = RRDBTrunk(3, nf, nb, gc, 1, rrdb_block_f, PixShuffleInitialConv(4, nf))
|
||||||
self.trunks = [self.trunk]
|
self.trunks = [self.trunk]
|
||||||
|
|
||||||
# Upsampling
|
# Upsampling
|
||||||
|
|
|
@ -35,23 +35,23 @@ def define_G(opt, net_key='network_G'):
|
||||||
elif which_model == 'AttentiveRRDBNet':
|
elif which_model == 'AttentiveRRDBNet':
|
||||||
netG = RRDBNet_arch.RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
|
netG = RRDBNet_arch.RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
|
||||||
nf=opt_net['nf'], nb=opt_net['nb'], scale=scale,
|
nf=opt_net['nf'], nb=opt_net['nb'], scale=scale,
|
||||||
rrdb_block_f=functools.partial(RRDBNet_arch.AttentiveRRDB, nf=opt_net['nf'], gc=opt_net['gc'],
|
rrdb_block_f=functools.partial(RRDBNet_arch.SwitchedRRDB, nf=opt_net['nf'], gc=opt_net['gc'],
|
||||||
init_temperature=opt_net['temperature'],
|
init_temperature=opt_net['temperature'],
|
||||||
final_temperature_step=opt_net['temperature_final_step']))
|
final_temperature_step=opt_net['temperature_final_step']))
|
||||||
elif which_model == 'MultiRRDBNet':
|
elif which_model == 'MultiRRDBNet':
|
||||||
block_f = None
|
block_f = None
|
||||||
if opt_net['attention']:
|
if opt_net['attention']:
|
||||||
block_f = functools.partial(RRDBNet_arch.AttentiveRRDB, nf=opt_net['nf'], gc=opt_net['gc'],
|
block_f = functools.partial(RRDBNet_arch.SwitchedRRDB, nf=opt_net['nf'], gc=opt_net['gc'],
|
||||||
init_temperature=opt_net['temperature'],
|
init_temperature=opt_net['temperature'],
|
||||||
final_temperature_step=opt_net['temperature_final_step'])
|
final_temperature_step=opt_net['temperature_final_step'])
|
||||||
netG = RRDBNet_arch.MultiRRDBNet(nf_base=opt_net['nf'], gc_base=opt_net['gc'], lo_blocks=opt_net['lo_blocks'],
|
netG = RRDBNet_arch.MultiRRDBNet(nf_base=opt_net['nf'], gc_base=opt_net['gc'], lo_blocks=opt_net['lo_blocks'],
|
||||||
hi_blocks=opt_net['hi_blocks'], scale=scale, rrdb_block_f=block_f)
|
hi_blocks=opt_net['hi_blocks'], scale=scale, rrdb_block_f=block_f)
|
||||||
elif which_model == 'PixRRDBNet':
|
elif which_model == 'PixRRDBNet':
|
||||||
block_f = None
|
block_f = None
|
||||||
if opt_net['attention']:
|
if opt_net['attention']:
|
||||||
block_f = functools.partial(RRDBNet_arch.AttentiveRRDB, nf=opt_net['nf'], gc=opt_net['gc'],
|
block_f = functools.partial(RRDBNet_arch.SwitchedRRDB, nf=opt_net['nf'], gc=opt_net['gc'],
|
||||||
init_temperature=opt_net['temperature'],
|
init_temperature=opt_net['temperature'],
|
||||||
final_temperature_step=opt_net['temperature_final_step'])
|
final_temperature_step=opt_net['temperature_final_step'])
|
||||||
netG = RRDBNet_arch.PixShuffleRRDB(nf=opt_net['nf'], nb=opt_net['nb'], gc=opt_net['gc'], scale=scale, rrdb_block_f=block_f)
|
netG = RRDBNet_arch.PixShuffleRRDB(nf=opt_net['nf'], nb=opt_net['nb'], gc=opt_net['gc'], scale=scale, rrdb_block_f=block_f)
|
||||||
elif which_model == 'ResGen':
|
elif which_model == 'ResGen':
|
||||||
netG = ResGen_arch.fixup_resnet34(nb_denoiser=opt_net['nb_denoiser'], nb_upsampler=opt_net['nb_upsampler'],
|
netG = ResGen_arch.fixup_resnet34(nb_denoiser=opt_net['nb_denoiser'], nb_upsampler=opt_net['nb_upsampler'],
|
||||||
|
|
Loading…
Reference in New Issue
Block a user