From e6e91a1d75bd74617a20eb83b854bf4af21e6fd8 Mon Sep 17 00:00:00 2001 From: James Betker Date: Fri, 24 Jul 2020 20:32:49 -0600 Subject: [PATCH] Add SRG4 Back to the idea that maybe what we need is a hybrid approach between pure switches and RDB. --- .idea/misc.xml | 1 + .../archs/SwitchedResidualGenerator_arch.py | 102 +++++++++++++++++- codes/models/networks.py | 9 ++ 3 files changed, 110 insertions(+), 2 deletions(-) diff --git a/.idea/misc.xml b/.idea/misc.xml index 28a804d8..f67039df 100644 --- a/.idea/misc.xml +++ b/.idea/misc.xml @@ -3,4 +3,5 @@ + \ No newline at end of file diff --git a/codes/models/archs/SwitchedResidualGenerator_arch.py b/codes/models/archs/SwitchedResidualGenerator_arch.py index ce12a315..9fb8b2c4 100644 --- a/codes/models/archs/SwitchedResidualGenerator_arch.py +++ b/codes/models/archs/SwitchedResidualGenerator_arch.py @@ -5,7 +5,7 @@ import torch.nn.functional as F import functools from collections import OrderedDict from models.archs.arch_util import ConvBnLelu, ConvGnSilu, ExpansionBlock -from models.archs.RRDBNet_arch import ResidualDenseBlock_5C +from models.archs.RRDBNet_arch import ResidualDenseBlock_5C, RRDB from models.archs.spinenet_arch import SpineNet from switched_conv_util import save_attention_to_image @@ -117,7 +117,10 @@ class ConfigurableSwitchComputer(nn.Module): tc = transform_count self.multiplexer = multiplexer_net(tc) - self.pre_transform = pre_transform_block() + if pre_transform_block: + self.pre_transform = pre_transform_block() + else: + self.pre_transform = None self.transforms = nn.ModuleList([transform_block() for _ in range(transform_count)]) self.add_noise = add_scalable_noise_to_transforms self.noise_scale = nn.Parameter(torch.full((1,), float(1e-3))) @@ -237,6 +240,101 @@ class ConfigurableSwitchedResidualGenerator2(nn.Module): val["switch_%i_histogram" % (i,)] = hists[i] return val + +# Equivalent to SRG2 - Uses RDB blocks in between two switches. +class ConfigurableSwitchedResidualGenerator4(nn.Module): + def __init__(self, switch_filters, switch_reductions, switch_processing_layers, trans_counts, trans_kernel_sizes, + trans_layers, transformation_filters, attention_norm, initial_temp=20, final_temperature_step=50000, heightened_temp_min=1, + heightened_final_step=50000, upsample_factor=1, + add_scalable_noise_to_transforms=False): + super(ConfigurableSwitchedResidualGenerator4, self).__init__() + self.initial_conv = ConvBnLelu(3, transformation_filters, norm=False, activation=False, bias=True) + self.upconv1 = ConvBnLelu(transformation_filters, transformation_filters, norm=False, bias=True) + self.upconv2 = ConvBnLelu(transformation_filters, transformation_filters, norm=False, bias=True) + self.hr_conv = ConvBnLelu(transformation_filters, transformation_filters, norm=False, bias=True) + + multiplx_fn = functools.partial(ConvBasisMultiplexer, transformation_filters, switch_filters, switch_reductions, + switch_processing_layers, trans_counts) + transform_fn = functools.partial(MultiConvBlock, transformation_filters, int(transformation_filters * 1.5), + transformation_filters, kernel_size=trans_kernel_sizes, depth=trans_layers, + weight_init_factor=.1) + self.rdb1 = RRDB(transformation_filters) + self.sw1 = ConfigurableSwitchComputer(transformation_filters, multiplx_fn, + pre_transform_block=None, transform_block=transform_fn, + attention_norm=attention_norm, + transform_count=trans_counts, init_temp=initial_temp, + add_scalable_noise_to_transforms=add_scalable_noise_to_transforms) + self.rdb2 = RRDB(transformation_filters) + self.sw1 = ConfigurableSwitchComputer(transformation_filters, multiplx_fn, + pre_transform_block=None, transform_block=transform_fn, + attention_norm=attention_norm, + transform_count=trans_counts, init_temp=initial_temp, + add_scalable_noise_to_transforms=add_scalable_noise_to_transforms) + self.rdb3 = RRDB(transformation_filters) + + self.final_conv = ConvBnLelu(transformation_filters, 3, norm=False, activation=False, bias=True) + self.transformation_counts = trans_counts + self.init_temperature = initial_temp + self.final_temperature_step = final_temperature_step + self.heightened_temp_min = heightened_temp_min + self.heightened_final_step = heightened_final_step + self.attentions = None + self.upsample_factor = upsample_factor + assert self.upsample_factor == 2 or self.upsample_factor == 4 + + def forward(self, x): + # This is a common bug when evaluating SRG2 generators. It needs to be configured properly in eval mode. Just fail. + if not self.train: + assert self.switches[0].switch.temperature == 1 + + x = self.initial_conv(x) + + x = self.rdb1(x) + x = self.sw1(x, True) + x = self.rdb2(x) + x = self.sw2(x, True) + x = self.rdb3(x) + + x = self.upconv1(F.interpolate(x, scale_factor=2, mode="nearest")) + if self.upsample_factor > 2: + x = F.interpolate(x, scale_factor=2, mode="nearest") + x = self.upconv2(x) + x = self.final_conv(self.hr_conv(x)) + return x, x + + def set_temperature(self, temp): + [sw.set_temperature(temp) for sw in self.switches] + + def update_for_step(self, step, experiments_path='.'): + if self.attentions: + temp = max(1, + 1 + self.init_temperature * (self.final_temperature_step - step) / self.final_temperature_step) + if temp == 1 and self.heightened_final_step and step > self.final_temperature_step and \ + self.heightened_final_step != 1: + # Once the temperature passes (1) it enters an inverted curve to match the linear curve from above. + # without this, the attention specificity "spikes" incredibly fast in the last few iterations. + h_steps_total = self.heightened_final_step - self.final_temperature_step + h_steps_current = min(step - self.final_temperature_step, h_steps_total) + # The "gap" will represent the steps that need to be traveled as a linear function. + h_gap = 1 / self.heightened_temp_min + temp = h_gap * h_steps_current / h_steps_total + # Invert temperature to represent reality on this side of the curve + temp = 1 / temp + self.set_temperature(temp) + if step % 50 == 0: + [save_attention_to_image(experiments_path, self.attentions[i], self.transformation_counts, step, "a%i" % (i+1,), l_mult=10) for i in range(len(self.attentions))] + + def get_debug_values(self, step): + temp = self.switches[0].switch.temperature + mean_hists = [compute_attention_specificity(att, 2) for att in self.attentions] + means = [i[0] for i in mean_hists] + hists = [i[1].clone().detach().cpu().flatten() for i in mean_hists] + val = {"switch_temperature": temp} + for i in range(len(means)): + val["switch_%i_specificity" % (i,)] = means[i] + val["switch_%i_histogram" % (i,)] = hists[i] + return val + class Interpolate(nn.Module): def __init__(self, factor): super(Interpolate, self).__init__() diff --git a/codes/models/networks.py b/codes/models/networks.py index d666cc94..a8d61c2b 100644 --- a/codes/models/networks.py +++ b/codes/models/networks.py @@ -79,6 +79,15 @@ def define_G(opt, net_key='network_G'): initial_temp=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step'], heightened_temp_min=opt_net['heightened_temp_min'], heightened_final_step=opt_net['heightened_final_step'], upsample_factor=scale, add_scalable_noise_to_transforms=opt_net['add_noise']) + elif which_model == "ConfigurableSwitchedResidualGenerator4": + netG = SwitchedGen_arch.ConfigurableSwitchedResidualGenerator4(switch_filters=opt_net['switch_filters'], + switch_reductions=opt_net['switch_reductions'], + switch_processing_layers=opt_net['switch_processing_layers'], trans_counts=opt_net['trans_counts'], + trans_kernel_sizes=opt_net['trans_kernel_sizes'], trans_layers=opt_net['trans_layers'], + transformation_filters=opt_net['transformation_filters'], attention_norm=opt_net['attention_norm'], + initial_temp=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step'], + heightened_temp_min=opt_net['heightened_temp_min'], heightened_final_step=opt_net['heightened_final_step'], + upsample_factor=scale, add_scalable_noise_to_transforms=opt_net['add_noise']) elif which_model == "ProgressiveSRG2": netG = psrg.GrowingSRGBase(progressive_step_schedule=opt_net['schedule'], switch_reductions=opt_net['reductions'], growth_fade_in_steps=opt_net['fade_in_steps'], switch_filters=opt_net['switch_filters'],