diff --git a/codes/models/archs/RRDBNet_arch.py b/codes/models/archs/RRDBNet_arch.py
index af0139ac..f5ca03e3 100644
--- a/codes/models/archs/RRDBNet_arch.py
+++ b/codes/models/archs/RRDBNet_arch.py
@@ -5,7 +5,6 @@ import torch.nn.functional as F
 import models.archs.arch_util as arch_util
 from models.archs.arch_util import PixelUnshuffle
 import torchvision
-import switched_conv as switched_conv
 
 
 class ResidualDenseBlock_5C(nn.Module):
@@ -32,91 +31,6 @@ class ResidualDenseBlock_5C(nn.Module):
         return x5 * 0.2 + x
 
 
-# Multiple 5-channel residual block that uses learned switching to diversify its outputs.
-# If multi_head_input=False: takes standard (b,f,w,h) input tensor; else takes (b,heads,f,w,h) input tensor. Note that the default RDB block does not support this format, so use SwitchedRDB_5C_MultiHead for this case.
-# If collapse_heads=True, outputs (b,f,w,h) tensor.
-# If collapse_heads=False, outputs (b,heads,f,w,h) tensor.
-class SwitchedRDB_5C(switched_conv.MultiHeadSwitchedAbstractBlock):
-    def __init__(self, nf=64, gc=32, num_convs=8, num_heads=2, include_skip_head=False, init_temperature=1, multi_head_input=False, collapse_heads=True, force_block=None):
-        if force_block is None:
-            rdb5c = functools.partial(ResidualDenseBlock_5C, nf, gc)
-        else:
-            rdb5c = force_block
-        super(SwitchedRDB_5C, self).__init__(
-            rdb5c,
-            nf,
-            num_convs,
-            num_heads,
-            att_kernel_size=3,
-            att_pads=1,
-            include_skip_head=include_skip_head,
-            initial_temperature=init_temperature,
-            multi_head_input=multi_head_input,
-            concat_heads_into_filters=collapse_heads,
-        )
-        self.collapse_heads = collapse_heads
-        if self.collapse_heads:
-            self.mhead_collapse = nn.Conv2d(num_heads * nf, nf, 1)
-            arch_util.initialize_weights([self.mhead_collapse], 1)
-
-        arch_util.initialize_weights([sw.attention_conv1 for sw in self.switches] +
-                                     [sw.attention_conv2 for sw in self.switches], 1)
-
-    def forward(self, x, output_attention_weights=False):
-        outs = super(SwitchedRDB_5C, self).forward(x, output_attention_weights)
-        if output_attention_weights:
-            outs, atts = outs
-
-        if self.collapse_heads:
-            # outs need to be collapsed back down to a single heads worth of data.
-            out = self.mhead_collapse(outs)
-        else:
-            out = outs
-
-        return out, atts
-
-
-# Implementation of ResidualDenseBlock_5C which compresses multiple switching heads via a Conv3d before doing RDB
-# computation.
-class ResidualDenseBlock_5C_WithMheadConverter(ResidualDenseBlock_5C):
-    def __init__(self, nf=64, gc=32, bias=True, heads=2):
-        # Switched blocks generally operate at low resolution, kernel size is much less important, therefore set to 1.
-        super(ResidualDenseBlock_5C_WithMheadConverter, self).__init__(nf=nf, gc=gc, bias=bias, late_stage_kernel_size=1,
-                                                                       late_stage_padding=0)
-        self.heads = heads
-        self.converter = nn.Conv3d(nf, nf, kernel_size=(heads, 1, 1), stride=(heads, 1, 1))
-        arch_util.initialize_weights(self.converter)
-
-    # Accepts input of shape (b, heads, f, w, h)
-    def forward(self, x):
-        # Permute filter dim to 1.
-        x = x.permute(0, 2, 1, 3, 4)
-        x = self.converter(x)
-        x = torch.squeeze(x, dim=2)
-        return super(ResidualDenseBlock_5C_WithMheadConverter, self).forward(x)
-
-
-# Multiple 5-channel residual block that uses learned switching to diversify its outputs. The difference between this
-# block and SwitchedRDB_5C is this block accepts multi-headed inputs of format (b,heads,f,w,h).
-#
-# It does this by performing a Conv3d on the first block, which convolves all heads and collapses them to a dimension
-# of 1. The tensor is then squeezed and performs identically to SwitchedRDB_5C from there.
-class SwitchedRDB_5C_MultiHead(SwitchedRDB_5C):
-    def __init__(self, nf=64, gc=32, num_convs=8, num_heads=2, include_skip_head=False, init_temperature=1, collapse_heads=False):
-        rdb5c = functools.partial(ResidualDenseBlock_5C_WithMheadConverter, nf, gc, heads=num_heads)
-        super(SwitchedRDB_5C_MultiHead, self).__init__(
-            nf=nf,
-            gc=gc,
-            num_convs=num_convs,
-            num_heads=num_heads,
-            include_skip_head=include_skip_head,
-            init_temperature=init_temperature,
-            multi_head_input=True,
-            collapse_heads=collapse_heads,
-            force_block=rdb5c,
-        )
-
-
 class RRDB(nn.Module):
     '''Residual in Residual Dense Block'''
 
@@ -145,49 +59,6 @@ class LowDimRRDB(RRDB):
         return self.shuffle(x)
 
 
-# RRDB block that uses switching on the individual RDB modules that compose it to increase learning diversity.
-class SwitchedRRDB(RRDB):
-    def __init__(self, nf, gc=32, num_convs=8, init_temperature=1, final_temperature_step=1, switching_block=SwitchedRDB_5C):
-        super(SwitchedRRDB, self).__init__(nf, gc)
-        self.RDB1 = switching_block(nf, gc, num_convs=num_convs, init_temperature=init_temperature)
-        self.RDB2 = switching_block(nf, gc, num_convs=num_convs, init_temperature=init_temperature)
-        self.RDB3 = switching_block(nf, gc, num_convs=num_convs, init_temperature=init_temperature)
-        self.init_temperature = init_temperature
-        self.final_temperature_step = final_temperature_step
-        self.running_mean = 0
-        self.running_count = 0
-
-    def set_temperature(self, temp):
-        [sw.set_attention_temperature(temp) for sw in self.RDB1.switches]
-        [sw.set_attention_temperature(temp) for sw in self.RDB2.switches]
-        [sw.set_attention_temperature(temp) for sw in self.RDB3.switches]
-
-    def forward(self, x):
-        out, att1 = self.RDB1(x, True)
-        out, att2 = self.RDB2(out, True)
-        out, att3 = self.RDB3(out, True)
-
-        a1mean, _ = switched_conv.compute_attention_specificity(att1, 2)
-        a2mean, _ = switched_conv.compute_attention_specificity(att2, 2)
-        a3mean, _ = switched_conv.compute_attention_specificity(att3, 2)
-        self.running_mean += (a1mean + a2mean + a3mean) / 3.0
-        self.running_count += 1
-
-        return out * 0.2 + x
-
-    def get_debug_values(self, step, prefix):
-        # Take the chance to update the temperature here.
-        temp = max(1, int(self.init_temperature * (self.final_temperature_step - step) / self.final_temperature_step))
-        self.set_temperature(temp)
-
-        # Intentionally overwrite attention_temperature from other RRDB blocks; these should be synced.
-        val = {"%s_attention_mean" % (prefix,): self.running_mean / self.running_count,
-               "attention_temperature": temp}
-        self.running_count = 0
-        self.running_mean = 0
-        return val
-
-
 # Identical to LowDimRRDB but wraps an RRDB rather than inheriting from it. TODO: remove LowDimRRDB when backwards
 # compatibility is no longer desired.
 class LowDimRRDBWrapper(nn.Module):
@@ -203,36 +74,6 @@ class LowDimRRDBWrapper(nn.Module):
         x = self.rrdb(x)
         return self.shuffle(x)
 
-# RRDB block that uses multi-headed switching on multiple individual RDB blocks to improve diversity. Multiple heads
-# are annealed internally. This variant has a depth of 4 RDB blocks, rather than 3 like others above.
-class SwitchedMultiHeadRRDB(SwitchedRRDB):
-    def __init__(self, nf, gc=32, num_convs=8, num_heads=2, init_temperature=1, final_temperature_step=1):
-        super(SwitchedMultiHeadRRDB, self).__init__(nf=nf, gc=gc, num_convs=num_convs, init_temperature=init_temperature, final_temperature_step=final_temperature_step)
-        self.RDB1 = SwitchedRDB_5C(nf, gc, num_convs=num_convs, num_heads=num_heads, include_skip_head=True, init_temperature=init_temperature, collapse_heads=False)
-        self.RDB2 = SwitchedRDB_5C_MultiHead(nf, gc, num_convs=num_convs, num_heads=num_heads, include_skip_head=True, init_temperature=init_temperature, collapse_heads=False)
-        self.RDB3 = SwitchedRDB_5C_MultiHead(nf, gc, num_convs=num_convs, num_heads=num_heads, include_skip_head=True, init_temperature=init_temperature, collapse_heads=False)
-        self.RDB4 = SwitchedRDB_5C_MultiHead(nf, gc, num_convs=num_convs, num_heads=num_heads, include_skip_head=True, init_temperature=init_temperature, collapse_heads=True)
-
-    def set_temperature(self, temp):
-        [sw.set_attention_temperature(temp) for sw in self.RDB1.switches]
-        [sw.set_attention_temperature(temp) for sw in self.RDB2.switches]
-        [sw.set_attention_temperature(temp) for sw in self.RDB3.switches]
-        [sw.set_attention_temperature(temp) for sw in self.RDB4.switches]
-
-    def forward(self, x):
-        out, att1 = self.RDB1(x, True)
-        out, att2 = self.RDB2(out, True)
-        out, att3 = self.RDB3(out, True)
-        out, att4 = self.RDB4(out, True)
-
-        a1mean, _ = switched_conv.compute_attention_specificity(att1, 2)
-        a2mean, _ = switched_conv.compute_attention_specificity(att2, 2)
-        a3mean, _ = switched_conv.compute_attention_specificity(att3, 2)
-        a4mean, _ = switched_conv.compute_attention_specificity(att4, 2)
-        self.running_mean += (a1mean + a2mean + a3mean + a4mean) / 3.0
-        self.running_count += 1
-
-        return out * 0.2 + x
 
 # This module performs the majority of the processing done by RRDBNet. It just doesn't have the upsampling at the end.
 class RRDBTrunk(nn.Module):
@@ -270,6 +111,7 @@ class RRDBTrunk(nn.Module):
             i += 1
         return val
 
+
 # Adds some base methods that all RRDB* classes will use.
 class RRDBBase(nn.Module):
     def __init__(self):
@@ -331,6 +173,7 @@ class RRDBNet(RRDBBase):
             state_dict["trunk.%s" % (k,)] = state_dict.pop(k)
         super(RRDBNet, self).load_state_dict(state_dict, strict)
 
+
 # Variant of RRDBNet that is "assisted" by an external pretrained image classifier whose
 # intermediate layers have been splayed out, pixel-shuffled, and fed back in.
 # TODO: Convert to use new RRDBBase hierarchy.
@@ -413,6 +256,7 @@ class AssistedRRDBNet(nn.Module):
 
         return (out,)
 
+
 class PixShuffleInitialConv(nn.Module):
     def __init__(self, reduction_factor, nf_out):
         super(PixShuffleInitialConv, self).__init__()
@@ -427,6 +271,7 @@ class PixShuffleInitialConv(nn.Module):
         x = self.unshuffle(x)
         return self.conv(x)
 
+
 # This class uses a RRDBTrunk to perform processing on an image, then upsamples it.
 class PixShuffleRRDB(RRDBBase):
     def __init__(self, nf, nb, gc=32, scale=2, rrdb_block_f=None):
diff --git a/codes/models/networks.py b/codes/models/networks.py
index e2740134..e6f32efc 100644
--- a/codes/models/networks.py
+++ b/codes/models/networks.py
@@ -32,25 +32,11 @@ def define_G(opt, net_key='network_G'):
     elif which_model == 'AssistedRRDBNet':
         netG = RRDBNet_arch.AssistedRRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
                                             nf=opt_net['nf'], nb=opt_net['nb'], scale=scale)
-    elif which_model == 'AttentiveRRDBNet':
-        netG = RRDBNet_arch.RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
-                                    nf=opt_net['nf'], nb=opt_net['nb'], scale=scale,
-                                    rrdb_block_f=functools.partial(RRDBNet_arch.SwitchedRRDB, nf=opt_net['nf'], gc=opt_net['gc'],
-                                                                   init_temperature=opt_net['temperature'],
-                                                                   final_temperature_step=opt_net['temperature_final_step']))
     elif which_model == 'LowDimRRDBNet':
         gen_scale = scale * opt_net['initial_stride']
         rrdb = functools.partial(RRDBNet_arch.LowDimRRDB, nf=opt_net['nf'], gc=opt_net['gc'], dimensional_adjustment=opt_net['dim'])
         netG = RRDBNet_arch.RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
                                     nf=opt_net['nf'], nb=opt_net['nb'], scale=gen_scale, rrdb_block_f=rrdb, initial_stride=opt_net['initial_stride'])
-    elif which_model == "LowDimRRDBWithMultiHeadSwitching":
-        gen_scale = scale * opt_net['initial_stride']
-        switcher = functools.partial(RRDBNet_arch.SwitchedMultiHeadRRDB, num_convs=opt_net['num_convs'], num_heads=opt_net['num_heads'],
-                                     init_temperature=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step'])
-        rrdb = functools.partial(RRDBNet_arch.LowDimRRDBWrapper, nf=opt_net['nf'], gc=opt_net['gc'], dimensional_adjustment=opt_net['dim'],
-                                 partial_rrdb=switcher)
-        netG = RRDBNet_arch.RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
-                                    nf=opt_net['nf'], nb=opt_net['nb'], scale=gen_scale, rrdb_block_f=rrdb, initial_stride=opt_net['initial_stride'])
     elif which_model == 'PixRRDBNet':
         block_f = None
         if opt_net['attention']: