From 59aba1daa72a55aad650f1912d561bb20c225abb Mon Sep 17 00:00:00 2001 From: James Betker Date: Mon, 10 Aug 2020 13:03:36 -0600 Subject: [PATCH] LR switched SPSR arch This variant doesn't do conv processing at HR, which should save a ton of memory in inference. Lets see how it works. --- codes/models/archs/SPSR_arch.py | 126 ++++++++++++++++++ .../archs/SwitchedResidualGenerator_arch.py | 6 +- codes/models/networks.py | 2 + 3 files changed, 132 insertions(+), 2 deletions(-) diff --git a/codes/models/archs/SPSR_arch.py b/codes/models/archs/SPSR_arch.py index dc1aedce..bee33c5e 100644 --- a/codes/models/archs/SPSR_arch.py +++ b/codes/models/archs/SPSR_arch.py @@ -546,3 +546,129 @@ class SwitchedSpsr(nn.Module): val["switch_%i_specificity" % (i,)] = means[i] val["switch_%i_histogram" % (i,)] = hists[i] return val + + + + +class SwitchedSpsrLr(nn.Module): + def __init__(self, in_nc, out_nc, nf, upscale=4): + super(SwitchedSpsrLr, self).__init__() + n_upscale = int(math.log(upscale, 2)) + + # switch options + transformation_filters = nf + switch_filters = nf + switch_reductions = 3 + switch_processing_layers = 2 + self.transformation_counts = 8 + multiplx_fn = functools.partial(ConvBasisMultiplexer, transformation_filters, switch_filters, switch_reductions, + switch_processing_layers, self.transformation_counts) + pretransform_fn = functools.partial(ConvGnLelu, transformation_filters, transformation_filters, norm=False, bias=False, weight_init_factor=.1) + transform_fn = functools.partial(MultiConvBlock, transformation_filters, int(transformation_filters * 1.5), + transformation_filters, kernel_size=3, depth=3, + weight_init_factor=.1) + + # Feature branch + self.model_fea_conv = ConvGnLelu(in_nc, nf, kernel_size=3, norm=False, activation=False) + self.sw1 = ConfigurableSwitchComputer(transformation_filters, multiplx_fn, + pre_transform_block=pretransform_fn, transform_block=transform_fn, + attention_norm=True, + transform_count=self.transformation_counts, init_temp=10, + add_scalable_noise_to_transforms=True) + self.sw2 = ConfigurableSwitchComputer(transformation_filters, multiplx_fn, + pre_transform_block=pretransform_fn, transform_block=transform_fn, + attention_norm=True, + transform_count=self.transformation_counts, init_temp=10, + add_scalable_noise_to_transforms=True) + self.feature_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False) + self.feature_hr_conv1 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=False) + self.feature_hr_conv2 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False, bias=False) + + # Grad branch + self.get_g_nopadding = ImageGradientNoPadding() + self.b_fea_conv = ConvGnLelu(in_nc, nf, kernel_size=3, norm=False, activation=False, bias=False) + self.sw_grad = ConfigurableSwitchComputer(transformation_filters, multiplx_fn, + pre_transform_block=pretransform_fn, transform_block=transform_fn, + attention_norm=True, + transform_count=self.transformation_counts, init_temp=10, + add_scalable_noise_to_transforms=True) + # Upsampling + self.grad_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=False) + grad_hr_conv1 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=False) + grad_hr_conv2 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False, bias=False) + self.branch_upsample = B.sequential(grad_hr_conv1, grad_hr_conv2) + # Conv used to output grad branch shortcut. + self.grad_branch_output_conv = ConvGnLelu(nf, out_nc, kernel_size=1, norm=False, activation=False, bias=False) + + # Conjoin branch. + # Note: "_branch_pretrain" is a special tag used to denote parameters that get pretrained before the rest. + transform_fn_cat = functools.partial(MultiConvBlock, transformation_filters * 2, int(transformation_filters * 1.5), + transformation_filters, kernel_size=3, depth=4, + weight_init_factor=.1) + pretransform_fn_cat = functools.partial(ConvGnLelu, transformation_filters * 2, transformation_filters * 2, norm=False, bias=False, weight_init_factor=.1) + self._branch_pretrain_sw = ConfigurableSwitchComputer(transformation_filters, multiplx_fn, + pre_transform_block=pretransform_fn_cat, transform_block=transform_fn_cat, + attention_norm=True, + transform_count=self.transformation_counts, init_temp=10, + add_scalable_noise_to_transforms=True) + self.upsample = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=False, activation=False, bias=False) for _ in range(n_upscale)]) + self.upsample_grad = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=False, activation=False, bias=False) for _ in range(n_upscale)]) + self._branch_pretrain_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False) + self._branch_pretrain_HR_conv0 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=False) + self._branch_pretrain_HR_conv1 = ConvGnLelu(nf, out_nc, kernel_size=3, norm=False, activation=False, bias=False) + self.switches = [self.sw1, self.sw2, self.sw_grad, self._branch_pretrain_sw] + self.attentions = None + self.init_temperature = 10 + self.final_temperature_step = 10000 + + def forward(self, x): + x_grad = self.get_g_nopadding(x) + x = self.model_fea_conv(x) + + x1, a1 = self.sw1(x, True) + x2, a2 = self.sw2(x1, True) + x_fea = self.feature_lr_conv(x2) + x_fea = self.feature_hr_conv1(x_fea) + x_fea = self.feature_hr_conv2(x_fea) + + x_b_fea = self.b_fea_conv(x_grad) + x_grad, a3 = self.sw_grad(x_b_fea, att_in=x1, output_attention_weights=True) + x_grad = self.grad_lr_conv(x_grad) + x_grad = self.branch_upsample(x_grad) + x_out_branch = self.upsample_grad(x_grad) + x_out_branch = self.grad_branch_output_conv(x_out_branch) + + x__branch_pretrain_cat = torch.cat([x_grad, x_fea], dim=1) + x__branch_pretrain_cat, a4 = self._branch_pretrain_sw(x__branch_pretrain_cat, att_in=x_fea, identity=x_fea, output_attention_weights=True) + x_out = self._branch_pretrain_lr_conv(x__branch_pretrain_cat) + x_out = self.upsample(x_out) + x_out = self._branch_pretrain_HR_conv0(x_out) + x_out = self._branch_pretrain_HR_conv1(x_out) + + self.attentions = [a1, a2, a3, a4] + + return x_out_branch, x_out, x_grad + + def set_temperature(self, temp): + [sw.set_temperature(temp) for sw in self.switches] + + def update_for_step(self, step, experiments_path='.'): + if self.attentions: + temp = max(1, 1 + self.init_temperature * + (self.final_temperature_step - step) / self.final_temperature_step) + self.set_temperature(temp) + if step % 10 == 0: + output_path = os.path.join(experiments_path, "attention_maps", "a%i") + prefix = "attention_map_%i_%%i.png" % (step,) + [save_attention_to_image_rgb(output_path % (i,), self.attentions[i], self.transformation_counts, prefix, step) for i in range(len(self.attentions))] + + def get_debug_values(self, step): + temp = self.switches[0].switch.temperature + mean_hists = [compute_attention_specificity(att, 2) for att in self.attentions] + means = [i[0] for i in mean_hists] + hists = [i[1].clone().detach().cpu().flatten() for i in mean_hists] + val = {"switch_temperature": temp} + for i in range(len(means)): + val["switch_%i_specificity" % (i,)] = means[i] + val["switch_%i_histogram" % (i,)] = hists[i] + return val diff --git a/codes/models/archs/SwitchedResidualGenerator_arch.py b/codes/models/archs/SwitchedResidualGenerator_arch.py index cc681a9f..3548211e 100644 --- a/codes/models/archs/SwitchedResidualGenerator_arch.py +++ b/codes/models/archs/SwitchedResidualGenerator_arch.py @@ -134,11 +134,13 @@ class ConfigurableSwitchComputer(nn.Module): # depending on its needs. self.psc_scale = nn.Parameter(torch.full((1,), float(.1))) - def forward(self, x, output_attention_weights=False, att_in=None, fixed_scale=1): + def forward(self, x, output_attention_weights=False, identity=None, att_in=None, fixed_scale=1): if att_in is None: att_in = x - identity = x + if identity is None: + identity = x + if self.add_noise: rand_feature = torch.randn_like(x) * self.noise_scale x = x + rand_feature diff --git a/codes/models/networks.py b/codes/models/networks.py index 85397f33..7c7a7c8e 100644 --- a/codes/models/networks.py +++ b/codes/models/networks.py @@ -113,6 +113,8 @@ def define_G(opt, net_key='network_G'): nb=opt_net['nb'], upscale=opt_net['scale']) elif which_model == "spsr_switched": netG = spsr.SwitchedSpsr(in_nc=3, out_nc=3, nf=opt_net['nf'], upscale=opt_net['scale']) + elif which_model == "spsr_switched_lr": + netG = spsr.SwitchedSpsrLr(in_nc=3, out_nc=3, nf=opt_net['nf'], upscale=opt_net['scale']) # image corruption elif which_model == 'HighToLowResNet':