diff --git a/codes/models/archs/SPSR_arch.py b/codes/models/archs/SPSR_arch.py index 533946a9..f72e26d8 100644 --- a/codes/models/archs/SPSR_arch.py +++ b/codes/models/archs/SPSR_arch.py @@ -522,7 +522,7 @@ class Spsr7(nn.Module): self.final_temperature_step = 10000 self.lr = None - def forward(self, x, ref, ref_center, only_return_final_feature_map=False): + def forward(self, x, ref, ref_center, update_attention_norm=True): # The attention_maps debugger outputs . Save that here. self.lr = x.detach().cpu() @@ -543,145 +543,12 @@ class Spsr7(nn.Module): x_grad, a3 = self.sw_grad(x_grad, True, identity=x_grad_identity, att_in=(x_grad, ref_embedding)) x_grad = self.grad_lr_conv(x_grad) x_grad = self.grad_lr_conv2(x_grad) - if not only_return_final_feature_map: - x_grad_out = self.upsample_grad(x_grad) - x_grad_out = self.grad_branch_output_conv(x_grad_out) + x_grad_out = self.upsample_grad(x_grad) + x_grad_out = self.grad_branch_output_conv(x_grad_out) x_out = x2 x_out, fea_grad_std = self.conjoin_ref_join(x_out, x_grad) x_out, a4 = self.conjoin_sw(x_out, True, identity=x2, att_in=(x_out, ref_embedding)) - x_out = self.final_lr_conv(x_out) - final_feature_map = x_out - if only_return_final_feature_map: - return final_feature_map - x_out = checkpoint(self.upsample, x_out) - x_out = checkpoint(self.final_hr_conv1, x_out) - x_out = self.final_hr_conv2(x_out) - - self.attentions = [a1, a2, a3, a4] - self.grad_fea_std = grad_fea_std.detach().cpu() - self.fea_grad_std = fea_grad_std.detach().cpu() - return x_grad_out, x_out, final_feature_map - - def set_temperature(self, temp): - [sw.set_temperature(temp) for sw in self.switches] - - def update_for_step(self, step, experiments_path='.'): - if self.attentions: - temp = max(1, 1 + self.init_temperature * - (self.final_temperature_step - step) / self.final_temperature_step) - self.set_temperature(temp) - if step % 500 == 0: - output_path = os.path.join(experiments_path, "attention_maps") - prefix = "amap_%i_a%i_%%i.png" - [save_attention_to_image_rgb(output_path, self.attentions[i], self.transformation_counts, prefix % (step, i), step, output_mag=False) for i in range(len(self.attentions))] - torchvision.utils.save_image(self.lr, os.path.join(experiments_path, "attention_maps", "amap_%i_base_image.png" % (step,))) - - def get_debug_values(self, step, net_name): - temp = self.switches[0].switch.temperature - mean_hists = [compute_attention_specificity(att, 2) for att in self.attentions] - means = [i[0] for i in mean_hists] - hists = [i[1].clone().detach().cpu().flatten() for i in mean_hists] - val = {"switch_temperature": temp, - "grad_branch_feat_intg_std_dev": self.grad_fea_std, - "conjoin_branch_grad_intg_std_dev": self.fea_grad_std} - for i in range(len(means)): - val["switch_%i_specificity" % (i,)] = means[i] - val["switch_%i_histogram" % (i,)] = hists[i] - return val - - -# Based on Spsr7 but swaps sw2 to the end of the chain. Also re-enables pretransform convs. -class Spsr8(nn.Module): - def __init__(self, in_nc, out_nc, nf, xforms=8, upscale=4, multiplexer_reductions=3, init_temperature=10): - super(Spsr8, self).__init__() - n_upscale = int(math.log(upscale, 2)) - - # processing the input embedding - self.reference_embedding = ReferenceImageBranch(nf) - - # switch options - self.nf = nf - transformation_filters = nf - self.transformation_counts = xforms - multiplx_fn = functools.partial(QueryKeyMultiplexer, transformation_filters, embedding_channels=512, reductions=multiplexer_reductions) - pretransform_fn = functools.partial(ConvGnLelu, transformation_filters, transformation_filters, norm=False, bias=False, weight_init_factor=.1) - transform_fn = functools.partial(MultiConvBlock, transformation_filters, int(transformation_filters * 1.5), - transformation_filters, kernel_size=3, depth=3, - weight_init_factor=.1) - - # Feature branch - self.model_fea_conv = ConvGnLelu(in_nc, nf, kernel_size=7, norm=False, activation=False) - self.sw1 = ConfigurableSwitchComputer(transformation_filters, multiplx_fn, - pre_transform_block=pretransform_fn, transform_block=transform_fn, - attention_norm=True, - transform_count=self.transformation_counts, init_temp=init_temperature, - add_scalable_noise_to_transforms=False, feed_transforms_into_multiplexer=True) - - # Grad branch. Note - groupnorm on this branch is REALLY bad. Avoid it like the plague. - self.get_g_nopadding = ImageGradientNoPadding() - self.grad_conv = ConvGnLelu(in_nc, nf, kernel_size=7, norm=False, activation=False, bias=False) - self.grad_ref_join = ReferenceJoinBlock(nf, residual_weight_init_factor=.3, final_norm=False) - - self.sw_grad = ConfigurableSwitchComputer(transformation_filters, multiplx_fn, - pre_transform_block=pretransform_fn, transform_block=transform_fn, - attention_norm=True, - transform_count=self.transformation_counts // 2, init_temp=init_temperature, - add_scalable_noise_to_transforms=False, feed_transforms_into_multiplexer=True) - self.grad_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=True) - self.grad_lr_conv2 = ConvGnLelu(nf, nf, kernel_size=1, norm=False, activation=True, bias=True) - self.upsample_grad = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=False, activation=True, bias=False) for _ in range(n_upscale)]) - self.grad_branch_output_conv = ConvGnLelu(nf, out_nc, kernel_size=1, norm=False, activation=False, bias=True) - - # Join branch (grad+fea) - self.noise_ref_join_conjoin = ReferenceJoinBlock(nf, residual_weight_init_factor=.1) - self.conjoin_ref_join = ReferenceJoinBlock(nf, residual_weight_init_factor=.3) - self.conjoin_sw = ConfigurableSwitchComputer(transformation_filters, multiplx_fn, - pre_transform_block=pretransform_fn, transform_block=transform_fn, - attention_norm=True, - transform_count=self.transformation_counts, init_temp=init_temperature, - add_scalable_noise_to_transforms=False, feed_transforms_into_multiplexer=True) - self.final_sw = ConfigurableSwitchComputer(transformation_filters, multiplx_fn, - pre_transform_block=pretransform_fn, transform_block=transform_fn, - attention_norm=True, - transform_count=self.transformation_counts, init_temp=init_temperature, - add_scalable_noise_to_transforms=False, feed_transforms_into_multiplexer=True) - self.final_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=True) - self.upsample = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=False, activation=True, bias=True) for _ in range(n_upscale)]) - self.final_hr_conv1 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False, bias=True) - self.final_hr_conv2 = ConvGnLelu(nf, out_nc, kernel_size=1, norm=False, activation=False, bias=False) - self.switches = [self.sw1, self.sw_grad, self.conjoin_sw, self.final_sw] - self.attentions = None - self.init_temperature = init_temperature - self.final_temperature_step = 10000 - self.lr = None - - def forward(self, x, ref, ref_center): - # The attention_maps debugger outputs . Save that here. - self.lr = x.detach().cpu() - - x_grad = self.get_g_nopadding(x) - ref_code = self.reference_embedding(ref, ref_center) - ref_embedding = ref_code.view(-1, self.nf * 8, 1, 1).repeat(1, 1, x.shape[2] // 8, x.shape[3] // 8) - - x = self.model_fea_conv(x) - x1 = x - x1, a1 = self.sw1(x1, True, identity=x, att_in=(x1, ref_embedding)) - - x_grad = self.grad_conv(x_grad) - x_grad_identity = x_grad - x_grad, grad_fea_std = self.grad_ref_join(x_grad, x1) - x_grad, a2 = self.sw_grad(x_grad, True, identity=x_grad_identity, att_in=(x_grad, ref_embedding)) - x_grad = self.grad_lr_conv(x_grad) - x_grad = self.grad_lr_conv2(x_grad) - x_grad_out = self.upsample_grad(x_grad) - x_grad_out = self.grad_branch_output_conv(x_grad_out) - - x_out = x1 - x_out, fea_grad_std = self.conjoin_ref_join(x_out, x_grad) - x_out, a3 = self.conjoin_sw(x_out, True, identity=x1, att_in=(x_out, ref_embedding)) - x_out, a4 = self.final_sw(x_out, True, identity=x_out, att_in=(x_out, ref_embedding)) - x_out = self.final_lr_conv(x_out) x_out = checkpoint(self.upsample, x_out) x_out = checkpoint(self.final_hr_conv1, x_out) @@ -719,3 +586,132 @@ class Spsr8(nn.Module): val["switch_%i_histogram" % (i,)] = hists[i] return val + +class AttentionBlock(nn.Module): + def __init__(self, nf, num_transforms, multiplexer_reductions, init_temperature=10, has_ref=True): + super(AttentionBlock, self).__init__() + self.nf = nf + self.transformation_counts = num_transforms + multiplx_fn = functools.partial(QueryKeyMultiplexer, nf, embedding_channels=512, reductions=multiplexer_reductions) + transform_fn = functools.partial(MultiConvBlock, nf, int(nf * 1.5), + nf, kernel_size=3, depth=4, + weight_init_factor=.1) + if has_ref: + self.ref_join = ReferenceJoinBlock(nf, residual_weight_init_factor=.3, final_norm=False) + else: + self.ref_join = None + self.switch = ConfigurableSwitchComputer(nf, multiplx_fn, + pre_transform_block=None, transform_block=transform_fn, + attention_norm=True, + transform_count=self.transformation_counts, init_temp=init_temperature, + add_scalable_noise_to_transforms=False, feed_transforms_into_multiplexer=True) + + def forward(self, x, mplex_ref=None, ref=None): + if self.ref_join is not None: + branch, ref_std = self.ref_join(x, ref) + return self.switch(branch, True, identity=x, att_in=(branch, mplex_ref)) + (ref_std,) + else: + return self.switch(x, True, identity=x, att_in=(x, mplex_ref)) + + +# SPSR7 with incremental improvements and also using the new AttentionBlock to save gpu memory. +class Spsr9(nn.Module): + def __init__(self, in_nc, out_nc, nf, xforms=8, upscale=4, multiplexer_reductions=3, init_temperature=10): + super(Spsr9, self).__init__() + n_upscale = int(math.log(upscale, 2)) + self.nf = nf + self.transformation_counts = xforms + + # processing the input embedding + self.reference_embedding = ReferenceImageBranch(nf) + + + # Feature branch + self.model_fea_conv = ConvGnLelu(in_nc, nf, kernel_size=7, norm=False, activation=False) + self.sw1 = AttentionBlock(nf, self.transformation_counts, multiplexer_reductions, init_temperature, False) + self.sw2 = AttentionBlock(nf, self.transformation_counts, multiplexer_reductions, init_temperature, False) + + # Grad branch. Note - groupnorm on this branch is REALLY bad. Avoid it like the plague. + self.get_g_nopadding = ImageGradientNoPadding() + self.grad_conv = ConvGnLelu(in_nc, nf, kernel_size=7, norm=False, activation=False, bias=False) + self.sw_grad = AttentionBlock(nf, self.transformation_counts // 2, multiplexer_reductions, init_temperature, True) + self.grad_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=True) + self.grad_lr_conv2 = ConvGnLelu(nf, nf, kernel_size=1, norm=False, activation=True, bias=True) + self.upsample_grad = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=False, activation=True, bias=False) for _ in range(n_upscale)]) + self.grad_branch_output_conv = ConvGnLelu(nf, out_nc, kernel_size=1, norm=False, activation=False, bias=True) + + # Join branch (grad+fea) + self.conjoin_sw = AttentionBlock(nf, self.transformation_counts, multiplexer_reductions, init_temperature, True) + self.final_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=True) + self.upsample = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=False, activation=True, bias=True) for _ in range(n_upscale)]) + self.final_hr_conv1 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False, bias=True) + self.final_hr_conv2 = ConvGnLelu(nf, out_nc, kernel_size=1, norm=False, activation=False, bias=False) + self.switches = [self.sw1.switch, self.sw2.switch, self.sw_grad.switch, self.conjoin_sw.switch] + self.attentions = None + self.init_temperature = init_temperature + self.final_temperature_step = 10000 + self.lr = None + + def forward(self, x, ref, ref_center, update_attention_norm=True): + # The attention_maps debugger outputs . Save that here. + self.lr = x.detach().cpu() + + for sw in self.switches: + sw.set_update_attention_norm(update_attention_norm) + + x_grad = self.get_g_nopadding(x) + ref_code = checkpoint(self.reference_embedding, ref, ref_center) + ref_embedding = ref_code.view(-1, self.nf * 8, 1, 1).repeat(1, 1, x.shape[2] // 8, x.shape[3] // 8) + + x = self.model_fea_conv(x) + x1 = x + x1, a1 = checkpoint(self.sw1, x1, ref_embedding) + x2 = x1 + x2, a2 = checkpoint(self.sw2, x2, ref_embedding) + + x_grad = self.grad_conv(x_grad) + x_grad_identity = x_grad + x_grad, a3, grad_fea_std = checkpoint(self.sw_grad, x_grad, ref_embedding, x1) + x_grad = self.grad_lr_conv(x_grad) + x_grad = self.grad_lr_conv2(x_grad) + x_grad_out = checkpoint(self.upsample_grad, x_grad) + x_grad_out = checkpoint(self.grad_branch_output_conv, x_grad_out) + + x_out = x2 + x_out, a4, fea_grad_std = checkpoint(self.conjoin_sw, x_out, ref_embedding, x_grad) + x_out = self.final_lr_conv(x_out) + x_out = checkpoint(self.upsample, x_out) + x_out = checkpoint(self.final_hr_conv1, x_out) + x_out = checkpoint(self.final_hr_conv2, x_out) + + self.attentions = [a1, a2, a3, a4] + self.grad_fea_std = grad_fea_std.detach().cpu() + self.fea_grad_std = fea_grad_std.detach().cpu() + return x_grad_out, x_out + + def set_temperature(self, temp): + [sw.set_temperature(temp) for sw in self.switches] + + def update_for_step(self, step, experiments_path='.'): + if self.attentions: + temp = max(1, 1 + self.init_temperature * + (self.final_temperature_step - step) / self.final_temperature_step) + self.set_temperature(temp) + if step % 500 == 0: + output_path = os.path.join(experiments_path, "attention_maps") + prefix = "amap_%i_a%i_%%i.png" + [save_attention_to_image_rgb(output_path, self.attentions[i], self.transformation_counts, prefix % (step, i), step, output_mag=False) for i in range(len(self.attentions))] + torchvision.utils.save_image(self.lr, os.path.join(experiments_path, "attention_maps", "amap_%i_base_image.png" % (step,))) + + def get_debug_values(self, step, net_name): + temp = self.switches[0].switch.temperature + mean_hists = [compute_attention_specificity(att, 2) for att in self.attentions] + means = [i[0] for i in mean_hists] + hists = [i[1].clone().detach().cpu().flatten() for i in mean_hists] + val = {"switch_temperature": temp, + "grad_branch_feat_intg_std_dev": self.grad_fea_std, + "conjoin_branch_grad_intg_std_dev": self.fea_grad_std} + for i in range(len(means)): + val["switch_%i_specificity" % (i,)] = means[i] + val["switch_%i_histogram" % (i,)] = hists[i] + return val diff --git a/codes/models/archs/SwitchedResidualGenerator_arch.py b/codes/models/archs/SwitchedResidualGenerator_arch.py index 644beeb1..99145063 100644 --- a/codes/models/archs/SwitchedResidualGenerator_arch.py +++ b/codes/models/archs/SwitchedResidualGenerator_arch.py @@ -7,14 +7,8 @@ from collections import OrderedDict from models.archs.arch_util import ConvBnLelu, ConvGnSilu, ExpansionBlock, ExpansionBlock2, ConvGnLelu, MultiConvBlock, SiLU from switched_conv_util import save_attention_to_image_rgb import os -from utils.util import checkpoint from models.archs.spinenet_arch import SpineNet - -# Set to true to relieve memory pressure by using utils.util in several memory-critical locations. -memory_checkpointing_enabled = True - - # VGG-style layer with Conv(stride2)->BN->Activation->Conv->BN->Activation # Doubles the input filter count. class HalvingProcessingBlock(nn.Module): @@ -136,19 +130,13 @@ class ConfigurableSwitchComputer(nn.Module): x = self.pre_transform(*x) if not isinstance(x, tuple): x = (x,) - if memory_checkpointing_enabled: - xformed = [checkpoint(t, *x) for t in self.transforms] - else: - xformed = [t(*x) for t in self.transforms] + xformed = [t(*x) for t in self.transforms] if not isinstance(att_in, tuple): att_in = (att_in,) if self.feed_transforms_into_multiplexer: att_in = att_in + (torch.stack(xformed, dim=1),) - if memory_checkpointing_enabled: - m = checkpoint(self.multiplexer, *att_in) - else: - m = self.multiplexer(*att_in) + m = self.multiplexer(*att_in) # It is assumed that [xformed] and [m] are collapsed into tensors at this point. outputs, attention = self.switch(xformed, m, True, self.update_norm) @@ -286,10 +274,10 @@ class BackboneEncoder(nn.Module): # [ref] will have a 'mask' channel which we cannot use with pretrained spinenet. ref = ref[:, :3, :, :] - ref_emb = checkpoint(self.ref_spine, ref)[0] + ref_emb = self.ref_spine(ref)[0] ref_code = gather_2d(ref_emb, ref_center_point // 8) # Divide by 8 to bring the center point to the correct location. - patch = checkpoint(self.patch_spine, x)[0] + patch = self.patch_spine(x)[0] ref_code_expanded = ref_code.view(-1, 256, 1, 1).repeat(1, 1, patch.shape[2], patch.shape[3]) combined = self.merge_process1(torch.cat([patch, ref_code_expanded], dim=1)) combined = self.merge_process2(combined) @@ -316,7 +304,7 @@ class BackboneEncoderNoRef(nn.Module): if self.interpolate_first: x = F.interpolate(x, scale_factor=2, mode="bicubic") - patch = checkpoint(self.patch_spine, x)[0] + patch = self.patch_spine(x)[0] return patch @@ -332,10 +320,10 @@ class BackboneSpinenetNoHead(nn.Module): self.merge_process3 = ConvGnSilu(384, 256, kernel_size=1, activation=False, norm=False, bias=True) def forward(self, x, ref, ref_center_point): - ref_emb = checkpoint(self.ref_spine, ref)[0] + ref_emb = self.ref_spine(ref)[0] ref_code = gather_2d(ref_emb, ref_center_point // 4) # Divide by 8 to bring the center point to the correct location. - patch = checkpoint(self.patch_spine, x)[0] + patch = self.patch_spine(x)[0] ref_code_expanded = ref_code.view(-1, 256, 1, 1).repeat(1, 1, patch.shape[2], patch.shape[3]) combined = self.merge_process1(torch.cat([patch, ref_code_expanded], dim=1)) combined = self.merge_process2(combined) diff --git a/codes/models/networks.py b/codes/models/networks.py index 823be304..933d9775 100644 --- a/codes/models/networks.py +++ b/codes/models/networks.py @@ -73,9 +73,9 @@ def define_G(opt, net_key='network_G', scale=None): netG = spsr.Spsr7(in_nc=3, out_nc=3, nf=opt_net['nf'], xforms=xforms, upscale=opt_net['scale'], multiplexer_reductions=opt_net['multiplexer_reductions'] if 'multiplexer_reductions' in opt_net.keys() else 3, init_temperature=opt_net['temperature'] if 'temperature' in opt_net.keys() else 10) - elif which_model == "spsr8": + elif which_model == "spsr9": xforms = opt_net['num_transforms'] if 'num_transforms' in opt_net.keys() else 8 - netG = spsr.Spsr8(in_nc=3, out_nc=3, nf=opt_net['nf'], xforms=xforms, upscale=opt_net['scale'], + netG = spsr.Spsr9(in_nc=3, out_nc=3, nf=opt_net['nf'], xforms=xforms, upscale=opt_net['scale'], multiplexer_reductions=opt_net['multiplexer_reductions'] if 'multiplexer_reductions' in opt_net.keys() else 3, init_temperature=opt_net['temperature'] if 'temperature' in opt_net.keys() else 10) elif which_model == "ssgr1":