class SwitchedSpsrWithRef4x(nn.Module):
    """Switched SPSR generator with a reference-image branch and two upsampling stages.

    The network has three parts:

    * a feature branch (``sw1`` -> ``sw2`` -> first upsampling stage),
    * a gradient branch fed by ``ImageGradientNoPadding`` of the input
      (``sw_grad`` -> first upsampling stage), and
    * a conjoin stage that concatenates both branches, runs them through a
      final switch and the second upsampling stage.

    NOTE(review): the "4x" in the name presumably comes from chaining two
    ``UpconvBlock`` stages that each double the resolution -- confirm against
    ``UpconvBlock``.

    Args:
        in_nc: Channel count of the input image.
        out_nc: Channel count of both image outputs.
        nf: Base filter count used throughout the network.
        xforms: Number of transforms per switch (the gradient switch uses
            ``xforms // 2``).
        upscale: Accepted for signature parity with the sibling generators;
            it does not influence the layers built here.
        init_temperature: Initial switch temperature, also used as the scale
            of the annealing schedule in ``update_for_step``.
    """

    def __init__(self, in_nc, out_nc, nf, xforms=8, upscale=4, init_temperature=10):
        super(SwitchedSpsrWithRef4x, self).__init__()

        # switch options
        transformation_filters = nf
        switch_filters = nf
        self.transformation_counts = xforms
        self.reference_processor = ReferenceImageBranch(transformation_filters)
        multiplx_fn = functools.partial(ReferencingConvMultiplexer, transformation_filters, switch_filters,
                                        self.transformation_counts)
        pretransform_fn = functools.partial(AdaInConvBlock, 512, transformation_filters, transformation_filters)
        transform_fn = functools.partial(MultiConvBlock, transformation_filters, int(transformation_filters * 1.5),
                                         transformation_filters, kernel_size=3, depth=3,
                                         weight_init_factor=.1)

        # Feature branch
        self.model_fea_conv = ConvGnLelu(in_nc, nf, kernel_size=3, norm=False, activation=False)
        self.sw1 = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
                                              pre_transform_block=pretransform_fn, transform_block=transform_fn,
                                              attention_norm=True,
                                              transform_count=self.transformation_counts, init_temp=init_temperature,
                                              add_scalable_noise_to_transforms=True)
        self.sw2 = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
                                              pre_transform_block=pretransform_fn, transform_block=transform_fn,
                                              attention_norm=True,
                                              transform_count=self.transformation_counts, init_temp=init_temperature,
                                              add_scalable_noise_to_transforms=True)
        self.feature_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=True, activation=False)
        self.stage1_up_fea = UpconvBlock(nf, nf, block=ConvGnLelu, norm=False, activation=False, bias=False)
        self.feature_hr_conv2 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False, bias=False)

        # Grad branch
        self.get_g_nopadding = ImageGradientNoPadding()
        self.b_fea_conv = ConvGnLelu(in_nc, nf, kernel_size=3, norm=False, activation=False, bias=False)
        # The grad multiplexer sees the concatenation of two nf-channel maps
        # (see `forward`), hence nf * 2 input filters.
        mplex_grad = functools.partial(ReferencingConvMultiplexer, nf * 2, nf * 2, self.transformation_counts // 2)
        self.sw_grad = ConfigurableSwitchComputer(transformation_filters, mplex_grad,
                                                  pre_transform_block=pretransform_fn, transform_block=transform_fn,
                                                  attention_norm=True,
                                                  transform_count=self.transformation_counts // 2,
                                                  init_temp=init_temperature,
                                                  add_scalable_noise_to_transforms=True)
        self.stage1_up_grad = UpconvBlock(nf, nf, block=ConvGnLelu, norm=False, activation=False, bias=False)

        # Upsampling
        self.grad_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=True, activation=True, bias=False)
        self.grad_hr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False, bias=False)
        # Conv used to output grad branch shortcut.
        self.grad_branch_output_conv = ConvGnLelu(nf, out_nc, kernel_size=1, norm=False, activation=False, bias=False)

        # Conjoin branch.
        # Its transforms consume the concatenated grad + feature maps (2 * nf
        # channels) and reduce them back to nf channels.
        transform_fn_cat = functools.partial(MultiConvBlock, transformation_filters * 2,
                                             int(transformation_filters * 1.5),
                                             transformation_filters, kernel_size=3, depth=4,
                                             weight_init_factor=.1)
        pretransform_fn_cat = functools.partial(AdaInConvBlock, 512, transformation_filters * 2,
                                                transformation_filters * 2)
        self._branch_pretrain_sw = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
                                                              pre_transform_block=pretransform_fn_cat,
                                                              transform_block=transform_fn_cat,
                                                              attention_norm=True,
                                                              transform_count=self.transformation_counts,
                                                              init_temp=init_temperature,
                                                              add_scalable_noise_to_transforms=True)
        self.stage2_up_fea = UpconvBlock(nf, nf, block=ConvGnLelu, norm=False, activation=False, bias=False)
        self.stage2_up_grad = UpconvBlock(nf, nf, block=ConvGnLelu, norm=False, activation=False, bias=False)
        self.final_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=True, activation=False)
        self.final_hr_conv1 = ConvGnLelu(nf, nf, kernel_size=3, norm=True, activation=True, bias=False)
        self.final_hr_conv2 = ConvGnLelu(nf, out_nc, kernel_size=3, norm=False, activation=False, bias=False)

        # Bookkeeping used by the temperature schedule and debug reporting.
        self.switches = [self.sw1, self.sw2, self.sw_grad, self._branch_pretrain_sw]
        self.attentions = None  # populated by `forward`
        self.init_temperature = init_temperature
        self.final_temperature_step = 10000

    def forward(self, x, ref, center_coord):
        """Run the generator.

        Args:
            x: Input image tensor.
            ref: Reference image, processed together with ``center_coord`` by
                ``self.reference_processor``.
            center_coord: Forwarded unchanged to the reference processor.

        Returns:
            Tuple ``(x_out_branch, x_out, x_grad)``: the gradient-branch image
            output, the main image output, and the gradient-branch features
            taken after the first upsampling stage.
        """
        x_grad = self.get_g_nopadding(x)
        ref = self.reference_processor(ref, center_coord)
        x = self.model_fea_conv(x)

        # Feature branch: two switches, then the first upsampling stage.
        x1, a1 = self.sw1((x, ref), True)
        x2, a2 = self.sw2((x1, ref), True)
        x_fea = self.feature_lr_conv(x2)
        x_fea = self.stage1_up_fea(x_fea)
        x_fea = self.feature_hr_conv2(x_fea)

        # Grad branch: attention is computed from the first switch's output
        # concatenated with the gradient features.
        x_b_fea = self.b_fea_conv(x_grad)
        x_grad, a3 = self.sw_grad((x_b_fea, ref), att_in=(torch.cat([x1, x_b_fea], dim=1), ref),
                                  output_attention_weights=True)
        x_grad = self.grad_lr_conv(x_grad)
        x_grad = self.stage1_up_grad(x_grad)
        x_grad = self.grad_hr_conv(x_grad)
        x_out_branch = self.stage2_up_grad(x_grad)
        x_out_branch = self.grad_branch_output_conv(x_out_branch)

        # Conjoin branch: merge both branches, then the second upsampling stage.
        x__branch_pretrain_cat = torch.cat([x_grad, x_fea], dim=1)
        x__branch_pretrain_cat, a4 = self._branch_pretrain_sw((x__branch_pretrain_cat, ref), att_in=(x_fea, ref),
                                                              identity=x_fea, output_attention_weights=True)
        x_out = self.final_lr_conv(x__branch_pretrain_cat)
        x_out = self.stage2_up_fea(x_out)
        x_out = self.final_hr_conv1(x_out)
        x_out = self.final_hr_conv2(x_out)

        # Kept for `update_for_step` / `get_debug_values`.
        self.attentions = [a1, a2, a3, a4]

        return x_out_branch, x_out, x_grad

    def set_temperature(self, temp):
        """Apply ``temp`` to every switch in ``self.switches``."""
        for sw in self.switches:
            sw.set_temperature(temp)

    def update_for_step(self, step, experiments_path='.'):
        """Anneal the switch temperature and periodically dump attention maps.

        Does nothing until ``forward`` has stored attention maps. Otherwise the
        temperature falls linearly from ``1 + init_temperature`` at step 0 and
        is clamped at 1 from ``final_temperature_step`` onwards. Every 200
        steps the stored attentions are written below
        ``<experiments_path>/attention_maps/a<i>``.
        """
        if not self.attentions:
            return

        temp = max(1, 1 + self.init_temperature *
                   (self.final_temperature_step - step) / self.final_temperature_step)
        self.set_temperature(temp)

        if step % 200 == 0:
            output_path = os.path.join(experiments_path, "attention_maps", "a%i")
            prefix = "attention_map_%i_%%i.png" % (step,)
            # NOTE(review): the full transform count is passed for every
            # switch, including `sw_grad`, which was built with half as many
            # transforms -- confirm this is what the image helper expects.
            for i, attention in enumerate(self.attentions):
                save_attention_to_image_rgb(output_path % (i,), attention, self.transformation_counts, prefix, step)

    def get_debug_values(self, step):
        """Return a dict of switch diagnostics.

        Always contains ``"switch_temperature"`` (read from the first switch).
        Once ``forward`` has run, it also contains, for each stored attention
        ``i``, ``"switch_<i>_specificity"`` and ``"switch_<i>_histogram"`` taken
        from ``compute_attention_specificity``. Before the first forward pass
        only the temperature is reported, rather than failing on the missing
        attentions.
        """
        val = {"switch_temperature": self.switches[0].switch.temperature}
        for i, att in enumerate(self.attentions or ()):
            stats = compute_attention_specificity(att, 2)
            val["switch_%i_specificity" % (i,)] = stats[0]
            val["switch_%i_histogram" % (i,)] = stats[1].clone().detach().cpu().flatten()
        return val
init_temperature=opt_net['temperature'] if 'temperature' in opt_net.keys() else 10) + elif which_model == "spsr_switched_with_ref4x": + xforms = opt_net['num_transforms'] if 'num_transforms' in opt_net.keys() else 8 + netG = spsr.SwitchedSpsrWithRef4x(in_nc=3, out_nc=3, nf=opt_net['nf'], xforms=xforms, upscale=opt_net['scale'], + init_temperature=opt_net['temperature'] if 'temperature' in opt_net.keys() else 10) # image corruption elif which_model == 'HighToLowResNet':