From 12e8fad079d95ca3541fb0aaae58f7a6e4fe8699 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Tue, 9 Jun 2020 13:28:55 -0600
Subject: [PATCH] Add several new RRDB architectures

---
 codes/models/archs/RRDBNet_arch.py | 196 +++++++++++++++++++++++------
 codes/models/networks.py           |  15 +++
 2 files changed, 174 insertions(+), 37 deletions(-)

diff --git a/codes/models/archs/RRDBNet_arch.py b/codes/models/archs/RRDBNet_arch.py
index ad0f0297..2f56e4f5 100644
--- a/codes/models/archs/RRDBNet_arch.py
+++ b/codes/models/archs/RRDBNet_arch.py
@@ -61,8 +61,6 @@ class RRDB(nn.Module):
         return out * 0.2 + x
 
 class AttentiveRRDB(RRDB):
-    counter = 0
-
     def __init__(self, nf, gc=32, num_convs=8, init_temperature=1, final_temperature_step=1):
         super(RRDB, self).__init__()
         self.RDB1 = SwitchedRDB_5C(nf, gc, num_convs, init_temperature)
@@ -72,8 +70,6 @@ class AttentiveRRDB(RRDB):
         self.final_temperature_step = final_temperature_step
         self.running_mean = 0
         self.running_count = 0
-        self.counter = AttentiveRRDB.counter
-        AttentiveRRDB.counter += 1
 
     def set_temperature(self, temp):
         self.RDB1.switcher.set_attention_temperature(temp)
@@ -93,37 +89,28 @@ class AttentiveRRDB(RRDB):
         return out * 0.2 + x
 
-    def get_debug_values(self, step):
+    def get_debug_values(self, step, prefix):
         # Take the chance to update the temperature here.
         temp = max(1, int(self.init_temperature * (self.final_temperature_step - step) / self.final_temperature_step))
         self.set_temperature(temp)
         # Intentionally overwrite attention_temperature from other RRDB blocks; these should be synced.
-        val = {"RRDB_%i_attention_mean" % (self.counter,): self.running_mean / self.running_count,
+        val = {"%s_attention_mean" % (prefix,): self.running_mean / self.running_count,
                "attention_temperature": temp}
         self.running_count = 0
         self.running_mean = 0
         return val
 
-class RRDBNet(nn.Module):
-    def __init__(self, in_nc, out_nc, nf, nb, gc=32, scale=2, initial_stride=1,
-                 rrdb_block_f=None):
-        super(RRDBNet, self).__init__()
+# This module performs the majority of the processing done by RRDBNet. It just doesn't have the upsampling at the end.
+class RRDBTrunk(nn.Module):
+    def __init__(self, nf_in, nf_out, nb, gc=32, initial_stride=1, rrdb_block_f=None):
+        super(RRDBTrunk, self).__init__()
         if rrdb_block_f is None:
-            rrdb_block_f = functools.partial(RRDB, nf=nf, gc=gc)
+            rrdb_block_f = functools.partial(RRDB, nf=nf_out, gc=gc)
 
-        self.scale = scale
-        self.conv_first = nn.Conv2d(in_nc, nf, 7, initial_stride, padding=3, bias=True)
+        self.conv_first = nn.Conv2d(nf_in, nf_out, 7, initial_stride, padding=3, bias=True)
         self.RRDB_trunk, self.rrdb_layers = arch_util.make_layer(rrdb_block_f, nb, True)
-        self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
-
-        #### upsampling
-        self.upconv1 = nn.Conv2d(nf, nf, 5, 1, padding=2, bias=True)
-        self.upconv2 = nn.Conv2d(nf, nf, 5, 1, padding=2, bias=True)
-        self.HRconv = nn.Conv2d(nf, nf, 5, 1, padding=2, bias=True)
-        self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)
-
-        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+        self.trunk_conv = nn.Conv2d(nf_out, nf_out, 3, 1, 1, bias=True)
 
     # Sets the softmax temperature of each RRDB layer. Only works if you are using attentive
     # convolutions.
@@ -135,6 +122,58 @@ class RRDBNet(nn.Module):
         fea = self.conv_first(x)
         trunk = self.trunk_conv(self.RRDB_trunk(fea))
         fea = fea + trunk
+        return fea
+
+    def get_debug_values(self, step, prefix):
+        val = {}
+        i = 0
+        for block in self.RRDB_trunk._modules.values():
+            if hasattr(block, "get_debug_values"):
+                val.update(block.get_debug_values(step, "%s_rdb_%i" % (prefix, i)))
+            i += 1
+        return val
+
+# Adds some base methods that all RRDB* classes will use.
+class RRDBBase(nn.Module):
+    def __init__(self):
+        super(RRDBBase, self).__init__()
+
+    # Sets the softmax temperature of each RRDB layer. Only works if you are using attentive
+    # convolutions.
+    def set_temperature(self, temp):
+        for trunk in self.trunks:
+            for layer in trunk.rrdb_layers:
+                layer.set_temperature(temp)
+
+    def get_debug_values(self, step):
+        val = {}
+        for i, trunk in enumerate(self.trunks):
+            for j, block in enumerate(trunk.RRDB_trunk._modules.values()):
+                if hasattr(block, "get_debug_values"):
+                    val.update(block.get_debug_values(step, "trunk_%i_block_%i" % (i, j)))
+        return val
+
+# This class uses an RRDBTrunk to perform processing on an image, then upsamples it.
+class RRDBNet(RRDBBase):
+    def __init__(self, in_nc, out_nc, nf, nb, gc=32, scale=2, initial_stride=1,
+                 rrdb_block_f=None):
+        super(RRDBNet, self).__init__()
+
+        # Trunk - does actual processing.
+        self.trunk = RRDBTrunk(in_nc, nf, nb, gc, initial_stride, rrdb_block_f)
+        self.trunks = [self.trunk]
+
+        # Upsampling
+        self.scale = scale
+        self.upconv1 = nn.Conv2d(nf, nf, 5, 1, padding=2, bias=True)
+        self.upconv2 = nn.Conv2d(nf, nf, 5, 1, padding=2, bias=True)
+        self.HRconv = nn.Conv2d(nf, nf, 5, 1, padding=2, bias=True)
+        self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)
+
+        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+
+    def forward(self, x):
+        fea = self.trunk(x)
 
         if self.scale >= 2:
             fea = F.interpolate(fea, scale_factor=2, mode='nearest')
@@ -146,15 +185,17 @@ class RRDBNet(nn.Module):
 
         return (out,)
 
-    def get_debug_values(self, step):
-        val = {}
-        for block in self.RRDB_trunk._modules.values():
-            if hasattr(block, "get_debug_values"):
-                val.update(block.get_debug_values(step))
-        return val
+    def load_state_dict(self, state_dict, strict=True):
+        # The parameters in self.trunk used to be in this class. To support loading legacy saves, restore them.
+        t_state = self.trunk.state_dict()
+        for k in t_state.keys():
+            if k in state_dict:  # Guard: keys from new-style saves are already prefixed with "trunk."
+                state_dict["trunk.%s" % (k,)] = state_dict.pop(k)
+        super(RRDBNet, self).load_state_dict(state_dict, strict)
 
 # Variant of RRDBNet that is "assisted" by an external pretrained image classifier whose
 # intermediate layers have been splayed out, pixel-shuffled, and fed back in.
+# TODO: Convert to use new RRDBBase hierarchy.
 class AssistedRRDBNet(nn.Module):
     # in_nc=number of input channels.
     # out_nc=number of output channels.
@@ -171,10 +211,9 @@ class AssistedRRDBNet(nn.Module):
 
         # Set-up the assist-net, which should do feature extraction for us.
         self.assistnet = torchvision.models.wide_resnet50_2(pretrained=True)
         self.set_enable_assistnet_training(False)
-        assist_nf = [2, 4, 8, 16]  # Fixed for resnet. Re-evaluate if using other networks.
-        self.assist1 = RRDB(nf + assist_nf[0], gc)
-        self.assist2 = RRDB(nf + sum(assist_nf[:2]), gc)
-        self.assist3 = RRDB(nf + sum(assist_nf[:3]), gc)
+        assist_nf = [4, 8, 16]  # Fixed for resnet. Re-evaluate if using other networks.
+        self.assist2 = RRDB(nf + assist_nf[0], gc)
+        self.assist3 = RRDB(nf + sum(assist_nf[:2]), gc)
         self.assist4 = RRDB(nf + sum(assist_nf), gc)
         nf = nf + sum(assist_nf)
@@ -195,6 +234,11 @@ class AssistedRRDBNet(nn.Module):
         p.requires_grad = en
 
     def res_extract(self, x):
+        # Width and height must be multiples of 16 to use this architecture. Check that here.
+        (b, f, w, h) = x.shape
+        assert w % 16 == 0
+        assert h % 16 == 0
+
         x = self.assistnet.conv1(x)
         x = self.assistnet.bn1(x)
         x = self.assistnet.relu(x)
@@ -206,16 +250,13 @@ class AssistedRRDBNet(nn.Module):
         l2 = F.pixel_shuffle(x, 8)
         x = self.assistnet.layer3(x)
         l3 = F.pixel_shuffle(x, 16)
-        x = self.assistnet.layer4(x)
-        l4 = F.pixel_shuffle(x, 32)
-        return l1, l2, l3, l4
+        return l1, l2, l3
 
     def forward(self, x):
         # Invoke the assistant net first.
-        l1, l2, l3, l4 = self.res_extract(x)
+        l1, l2, l3 = self.res_extract(x)
         fea = self.conv_first(x)
-        fea = self.assist1(torch.cat([fea, l4], dim=1))
         fea = self.assist2(torch.cat([fea, l3], dim=1))
         fea = self.assist3(torch.cat([fea, l2], dim=1))
         fea = self.assist4(torch.cat([fea, l1], dim=1))
@@ -231,4 +272,85 @@ class AssistedRRDBNet(nn.Module):
         fea = self.lrelu(self.upconv2(fea))
         out = self.conv_last(self.lrelu(self.HRconv(fea)))
 
-        return (out,)
\ No newline at end of file
+        return (out,)
+
+
+# This class uses an RRDBTrunk to perform processing on an image, then upsamples it.
+class PixShuffleRRDB(RRDBBase):
+    def __init__(self, nf, nb, gc=32, scale=2, rrdb_block_f=None):
+        super(PixShuffleRRDB, self).__init__()
+
+        # This class does a 4x pixel shuffle on the filter count inside the trunk, so nf must be divisible by 16.
+        assert nf % 16 == 0
+
+        # Trunk - does actual processing.
+        self.trunk = RRDBTrunk(3, nf, nb, gc, 4, rrdb_block_f)
+        self.trunks = [self.trunk]
+
+        # Upsampling
+        pix_nf = int(nf/16)
+        self.scale = scale
+        self.upconv1 = nn.Conv2d(pix_nf, pix_nf, 5, 1, padding=2, bias=True)
+        self.upconv2 = nn.Conv2d(pix_nf, pix_nf, 5, 1, padding=2, bias=True)
+        self.HRconv = nn.Conv2d(pix_nf, pix_nf, 5, 1, padding=2, bias=True)
+        self.conv_last = nn.Conv2d(pix_nf, 3, 3, 1, 1, bias=True)
+        self.pixel_shuffle = nn.PixelShuffle(4)
+        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+
+    def forward(self, x):
+        fea = self.trunk(x)
+        fea = self.pixel_shuffle(fea)
+
+        if self.scale >= 2:
+            fea = F.interpolate(fea, scale_factor=2, mode='nearest')
+        fea = self.lrelu(self.upconv1(fea))
+        if self.scale >= 4:
+            fea = F.interpolate(fea, scale_factor=2, mode='nearest')
+        fea = self.lrelu(self.upconv2(fea))
+        out = self.conv_last(self.lrelu(self.HRconv(fea)))
+
+        return (out,)
+
+
+# This class uses two RRDB trunks to process an image at different resolution levels.
+class MultiRRDBNet(RRDBBase):
+    def __init__(self, nf_base, gc_base, lo_blocks, hi_blocks, scale=2, rrdb_block_f=None):
+        super(MultiRRDBNet, self).__init__()
+
+        # Initial downsampling.
+        self.conv_first = nn.Conv2d(3, nf_base, 5, stride=2, padding=2, bias=True)
+
+        # Chained trunks
+        lo_nf = nf_base * 4
+        hi_nf = nf_base
+        self.lo_trunk = RRDBTrunk(nf_base, lo_nf, lo_blocks, gc_base * 2, initial_stride=2, rrdb_block_f=rrdb_block_f)
+        self.hi_trunk = RRDBTrunk(nf_base, hi_nf, hi_blocks, gc_base, initial_stride=1, rrdb_block_f=rrdb_block_f)
+        self.trunks = [self.lo_trunk, self.hi_trunk]
+
+        # Upsampling
+        self.scale = scale
+        self.upconv1 = nn.Conv2d(hi_nf, hi_nf, 5, 1, padding=2, bias=True)
+        self.upconv2 = nn.Conv2d(hi_nf, hi_nf, 5, 1, padding=2, bias=True)
+        self.HRconv = nn.Conv2d(hi_nf, hi_nf, 5, 1, padding=2, bias=True)
+        self.conv_last = nn.Conv2d(hi_nf, 3, 3, 1, 1, bias=True)
+        self.pixel_shuffle = nn.PixelShuffle(2)
+
+        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+
+    def forward(self, x):
+        fea = self.conv_first(x)
+        fea_lo = self.lo_trunk(fea)
+        fea = self.pixel_shuffle(fea_lo) + fea
+        fea = self.hi_trunk(fea)
+
+        # First, return image to original size and perform post-processing.
+        fea = F.interpolate(fea, scale_factor=2, mode='nearest')
+        fea = self.lrelu(self.upconv1(fea))
+
+        # If 2x scaling is specified, do that too.
+        if self.scale >= 2:
+            fea = F.interpolate(fea, scale_factor=2, mode='nearest')
+        fea = self.lrelu(self.upconv2(fea))
+        out = self.conv_last(self.lrelu(self.HRconv(fea)))
+        return (out,)
\ No newline at end of file
diff --git a/codes/models/networks.py b/codes/models/networks.py
index 0671c6a8..702a989a 100644
--- a/codes/models/networks.py
+++ b/codes/models/networks.py
@@ -38,6 +38,21 @@ def define_G(opt, net_key='network_G'):
                                         rrdb_block_f=functools.partial(RRDBNet_arch.AttentiveRRDB, nf=opt_net['nf'],
                                                                        gc=opt_net['gc'], init_temperature=opt_net['temperature'],
                                                                        final_temperature_step=opt_net['temperature_final_step']))
+    elif which_model == 'MultiRRDBNet':
+        block_f = None
+        if opt_net['attention']:
+            block_f = functools.partial(RRDBNet_arch.AttentiveRRDB, nf=opt_net['nf'], gc=opt_net['gc'],
+                                        init_temperature=opt_net['temperature'],
+                                        final_temperature_step=opt_net['temperature_final_step'])
+        netG = RRDBNet_arch.MultiRRDBNet(nf_base=opt_net['nf'], gc_base=opt_net['gc'], lo_blocks=opt_net['lo_blocks'],
+                                         hi_blocks=opt_net['hi_blocks'], scale=scale, rrdb_block_f=block_f)
+    elif which_model == 'PixRRDBNet':
+        block_f = None
+        if opt_net['attention']:
+            block_f = functools.partial(RRDBNet_arch.AttentiveRRDB, nf=opt_net['nf'], gc=opt_net['gc'],
+                                        init_temperature=opt_net['temperature'],
+                                        final_temperature_step=opt_net['temperature_final_step'])
+        netG = RRDBNet_arch.PixShuffleRRDB(nf=opt_net['nf'], nb=opt_net['nb'], gc=opt_net['gc'], scale=scale, rrdb_block_f=block_f)
     elif which_model == 'ResGen':
         netG = ResGen_arch.fixup_resnet34(nb_denoiser=opt_net['nb_denoiser'], nb_upsampler=opt_net['nb_upsampler'],
                                           upscale_applications=opt_net['upscale_applications'], num_filters=opt_net['nf'])
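
Reviewer note (not part of the patch): a minimal smoke-test sketch for the three architectures this change touches. It assumes it is run from the codes/ directory so that models.archs.RRDBNet_arch resolves to the patched file; the nf/nb values are small illustrative choices, while the constructor signatures and the 1-tuple return convention come straight from the diff above.

    import torch
    import models.archs.RRDBNet_arch as RRDBNet_arch

    x = torch.randn(1, 3, 64, 64)
    nets = [
        # Classic RRDBNet pipeline, now backed by an RRDBTrunk internally.
        RRDBNet_arch.RRDBNet(in_nc=3, out_nc=3, nf=64, nb=2, scale=2),
        # 4x-strided trunk followed by a 4x pixel shuffle; nf must be divisible by 16.
        RRDBNet_arch.PixShuffleRRDB(nf=64, nb=2, scale=2),
        # Two chained trunks operating at different resolution levels.
        RRDBNet_arch.MultiRRDBNet(nf_base=32, gc_base=16, lo_blocks=2, hi_blocks=2, scale=2),
    ]
    with torch.no_grad():
        for net in nets:
            out = net(x)[0]  # every forward() in this file returns a 1-tuple
            assert out.shape == (1, 3, 128, 128)  # scale=2 doubles the 64px input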
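And a sketch of selecting the new 'PixRRDBNet' option through define_G. The keys read by the new branches (attention, temperature, temperature_final_step, nf, nb, gc, and for MultiRRDBNet also lo_blocks/hi_blocks) mirror the hunk above; the surrounding opt structure (which_model_G, top-level scale) is assumed to follow the mmsr-style conventions the rest of define_G uses, and the values are illustrative only.

    import models.networks as networks

    opt = {
        'scale': 2,
        'network_G': {
            'which_model_G': 'PixRRDBNet',  # 'MultiRRDBNet' would additionally need 'lo_blocks'/'hi_blocks'
            'nf': 64, 'nb': 2, 'gc': 32,
            # Attentive convolutions are opt-in; the temperature keys are only read when enabled.
            'attention': True,
            'temperature': 10,
            'temperature_final_step': 50000,
        },
    }
    netG = networks.define_G(opt)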