diff --git a/codes/models/archs/RRDBNet_arch.py b/codes/models/archs/RRDBNet_arch.py index f0517df4..1eea2624 100644 --- a/codes/models/archs/RRDBNet_arch.py +++ b/codes/models/archs/RRDBNet_arch.py @@ -59,7 +59,6 @@ class SwitchedRDB_5C(switched_conv.MultiHeadSwitchedAbstractBlock): return out, atts - class RRDB(nn.Module): '''Residual in Residual Dense Block''' @@ -75,7 +74,8 @@ class RRDB(nn.Module): out = self.RDB3(out) return out * 0.2 + x -class AttentiveRRDB(RRDB): +# RRDB block that uses switching on the individual RDB modules that compose it to increase learning diversity. +class SwitchedRRDB(RRDB): def __init__(self, nf, gc=32, num_convs=8, init_temperature=1, final_temperature_step=1): super(RRDB, self).__init__() self.RDB1 = SwitchedRDB_5C(nf, gc, num_convs=num_convs, init_temperature=init_temperature) @@ -118,12 +118,16 @@ class AttentiveRRDB(RRDB): # This module performs the majority of the processing done by RRDBNet. It just doesn't have the upsampling at the end. class RRDBTrunk(nn.Module): - def __init__(self, nf_in, nf_out, nb, gc=32, initial_stride=1, rrdb_block_f=None): + def __init__(self, nf_in, nf_out, nb, gc=32, initial_stride=1, rrdb_block_f=None, conv_first_block=None): super(RRDBTrunk, self).__init__() if rrdb_block_f is None: rrdb_block_f = functools.partial(RRDB, nf=nf_out, gc=gc) - self.conv_first = nn.Conv2d(nf_in, nf_out, 7, initial_stride, padding=3, bias=True) + if conv_first_block is None: + self.conv_first = nn.Conv2d(nf_in, nf_out, 7, initial_stride, padding=3, bias=True) + else: + self.conv_first = conv_first_block + self.RRDB_trunk, self.rrdb_layers = arch_util.make_layer(rrdb_block_f, nb, True) self.trunk_conv = nn.Conv2d(nf_out, nf_out, 3, 1, 1, bias=True) @@ -168,6 +172,7 @@ class RRDBBase(nn.Module): val.update(block.get_debug_values(step, "trunk_%i_block_%i" % (i, j))) return val + # This class uses a RRDBTrunk to perform processing on an image, then upsamples it. 
# Diff context that preceded the added class on this source line (kept verbatim):
# class RRDBNet(RRDBBase): def __init__(self, in_nc, out_nc, nf, nb, gc=32, scale=2, initial_stride=1,
# @@ -290,6 +295,22 @@ class AssistedRRDBNet(nn.Module): return (out,)


class PixShuffleInitialConv(nn.Module):
    """Initial conv that trades spatial resolution for channels before a 1x1 conv.

    The input is rearranged with a "reverse pixel shuffle": every
    ``reduction_factor x reduction_factor`` spatial patch is folded into the
    channel dimension, so a ``(b, c, h, w)`` tensor becomes
    ``(b, c * r**2, h // r, w // r)``. A 1x1 convolution then maps the
    ``c * r**2`` channels to ``nf_out``.

    Args:
        reduction_factor: Factor ``r`` by which both spatial dimensions are
            reduced. Must be a positive integer.
        nf_out: Number of output channels produced by the 1x1 convolution.
        nf_in: Number of channels of the incoming tensor. Defaults to 3, the
            only value the module supported before this parameter existed.

    Raises:
        ValueError: If ``reduction_factor`` is smaller than 1.
    """

    def __init__(self, reduction_factor, nf_out, nf_in=3):
        super(PixShuffleInitialConv, self).__init__()
        if reduction_factor < 1:
            raise ValueError(
                "reduction_factor must be >= 1, got %r" % (reduction_factor,))
        self.r = reduction_factor
        self.nf_in = nf_in
        # Attribute name `conv` is kept so existing checkpoints still load.
        self.conv = nn.Conv2d(nf_in * (reduction_factor ** 2), nf_out, 1)

    def forward(self, x):
        """Fold spatial patches into channels, then apply the 1x1 conv.

        Args:
            x: 4-D tensor ``(b, nf_in, h, w)`` whose two spatial sizes are
                both divisible by the reduction factor.

        Returns:
            Tensor of shape ``(b, nf_out, h // r, w // r)``.

        Raises:
            ValueError: If the channel count differs from ``nf_in`` or a
                spatial size is not divisible by the reduction factor.
        """
        # Both spatial axes are treated identically below, so the h/w naming
        # only documents the layout nn.Conv2d expects.
        b, c, h, w = x.shape
        r = self.r
        # Explicit checks instead of `assert`: asserts vanish under `python -O`.
        if c != self.nf_in:
            raise ValueError(
                "PixShuffleInitialConv expects %d input channels, got %d"
                % (self.nf_in, c))
        if h % r != 0 or w % r != 0:
            # Without this check the `view` below fails with an opaque
            # shape-mismatch error.
            raise ValueError(
                "Spatial size (%d, %d) is not divisible by reduction factor %d"
                % (h, w, r))
        # Split each spatial axis into (blocks, offset-within-block) ...
        x = x.contiguous().view(b, c, h // r, r, w // r, r)
        # ... then move both offsets next to the channel axis and merge them,
        # giving channel index = c * r**2 + row_offset * r + col_offset.
        x = x.permute(0, 1, 3, 5, 2, 4).contiguous()
        x = x.view(b, c * (r ** 2), h // r, w // r)
        # Apply the conv to bring the filter count to the desired size.
        return self.conv(x)


# Diff context that followed the added class on this source line (kept verbatim):
# # This class uses a RRDBTrunk to perform processing on an image, then upsamples it.
# class PixShuffleRRDB(RRDBBase): def __init__(self, nf, nb, gc=32, scale=2, rrdb_block_f=None):
# @@ -299,7 +320,7 @@ class PixShuffleRRDB(RRDBBase): assert nf % 16 == 0 # Trunk - does actual processing.
- self.trunk = RRDBTrunk(3, nf, nb, gc, 4, rrdb_block_f) + self.trunk = RRDBTrunk(3, nf, nb, gc, 1, rrdb_block_f, PixShuffleInitialConv(4, nf)) self.trunks = [self.trunk] # Upsampling diff --git a/codes/models/networks.py b/codes/models/networks.py index 702a989a..89bfc384 100644 --- a/codes/models/networks.py +++ b/codes/models/networks.py @@ -35,23 +35,23 @@ def define_G(opt, net_key='network_G'): elif which_model == 'AttentiveRRDBNet': netG = RRDBNet_arch.RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'], nb=opt_net['nb'], scale=scale, - rrdb_block_f=functools.partial(RRDBNet_arch.AttentiveRRDB, nf=opt_net['nf'], gc=opt_net['gc'], + rrdb_block_f=functools.partial(RRDBNet_arch.SwitchedRRDB, nf=opt_net['nf'], gc=opt_net['gc'], init_temperature=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step'])) elif which_model == 'MultiRRDBNet': block_f = None if opt_net['attention']: - block_f = functools.partial(RRDBNet_arch.AttentiveRRDB, nf=opt_net['nf'], gc=opt_net['gc'], - init_temperature=opt_net['temperature'], - final_temperature_step=opt_net['temperature_final_step']) + block_f = functools.partial(RRDBNet_arch.SwitchedRRDB, nf=opt_net['nf'], gc=opt_net['gc'], + init_temperature=opt_net['temperature'], + final_temperature_step=opt_net['temperature_final_step']) netG = RRDBNet_arch.MultiRRDBNet(nf_base=opt_net['nf'], gc_base=opt_net['gc'], lo_blocks=opt_net['lo_blocks'], hi_blocks=opt_net['hi_blocks'], scale=scale, rrdb_block_f=block_f) elif which_model == 'PixRRDBNet': block_f = None if opt_net['attention']: - block_f = functools.partial(RRDBNet_arch.AttentiveRRDB, nf=opt_net['nf'], gc=opt_net['gc'], - init_temperature=opt_net['temperature'], - final_temperature_step=opt_net['temperature_final_step']) + block_f = functools.partial(RRDBNet_arch.SwitchedRRDB, nf=opt_net['nf'], gc=opt_net['gc'], + init_temperature=opt_net['temperature'], + final_temperature_step=opt_net['temperature_final_step']) netG = 
RRDBNet_arch.PixShuffleRRDB(nf=opt_net['nf'], nb=opt_net['nb'], gc=opt_net['gc'], scale=scale, rrdb_block_f=block_f) elif which_model == 'ResGen': netG = ResGen_arch.fixup_resnet34(nb_denoiser=opt_net['nb_denoiser'], nb_upsampler=opt_net['nb_upsampler'],