diff --git a/codes/models/archs/SPSR_arch.py b/codes/models/archs/SPSR_arch.py
index b536490e..ad23a5fd 100644
--- a/codes/models/archs/SPSR_arch.py
+++ b/codes/models/archs/SPSR_arch.py
@@ -10,7 +10,7 @@ from utils.util import checkpoint
 from models.archs import SPSR_util as B
 from models.archs.SwitchedResidualGenerator_arch import ConfigurableSwitchComputer, ReferenceImageBranch, \
-    QueryKeyMultiplexer, QueryKeyPyramidMultiplexer
+    QueryKeyMultiplexer, QueryKeyPyramidMultiplexer, ConvBasisMultiplexer
 from models.archs.arch_util import ConvGnLelu, UpconvBlock, MultiConvBlock, ReferenceJoinBlock
 from switched_conv.switched_conv import compute_attention_specificity
 from switched_conv.switched_conv_util import save_attention_to_image_rgb
@@ -225,8 +225,6 @@ class Spsr5(nn.Module):
                                                    attention_norm=True,
                                                    transform_count=self.transformation_counts, init_temp=init_temperature,
                                                    add_scalable_noise_to_transforms=False, feed_transforms_into_multiplexer=True)
-        self.feature_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=True, activation=False)
-        self.feature_lr_conv2 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False, bias=False)
 
         # Grad branch. Note - groupnorm on this branch is REALLY bad. Avoid it like the plague.
         self.get_g_nopadding = ImageGradientNoPadding()
@@ -726,3 +724,125 @@ class Spsr9(nn.Module):
             val["switch_%i_specificity" % (i,)] = means[i]
             val["switch_%i_histogram" % (i,)] = hists[i]
         return val
+
+
+class SwitchedSpsr(nn.Module):
+    def __init__(self, in_nc, nf, xforms=8, upscale=4, init_temperature=10):
+        super(SwitchedSpsr, self).__init__()
+        n_upscale = int(math.log(upscale, 2))
+
+        # switch options
+        transformation_filters = nf
+        switch_filters = nf
+        switch_reductions = 3
+        switch_processing_layers = 2
+        self.transformation_counts = xforms
+        multiplx_fn = functools.partial(ConvBasisMultiplexer, transformation_filters, switch_filters, switch_reductions,
+                                        switch_processing_layers, self.transformation_counts, use_exp2=True)
+        pretransform_fn = functools.partial(ConvGnLelu, transformation_filters, transformation_filters, norm=False, bias=False, weight_init_factor=.1)
+        transform_fn = functools.partial(MultiConvBlock, transformation_filters, int(transformation_filters * 1.5),
+                                         transformation_filters, kernel_size=3, depth=3,
+                                         weight_init_factor=.1)
+
+        # Feature branch
+        self.model_fea_conv = ConvGnLelu(in_nc, nf, kernel_size=7, norm=False, activation=False)
+        self.sw1 = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
+                                              pre_transform_block=pretransform_fn, transform_block=transform_fn,
+                                              attention_norm=True,
+                                              transform_count=self.transformation_counts, init_temp=init_temperature,
+                                              add_scalable_noise_to_transforms=True)
+        self.sw2 = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
+                                              pre_transform_block=pretransform_fn, transform_block=transform_fn,
+                                              attention_norm=True,
+                                              transform_count=self.transformation_counts, init_temp=init_temperature,
+                                              add_scalable_noise_to_transforms=True)
+        self.feature_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=True, activation=False)
+        self.feature_hr_conv2 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False, bias=False)
+
+        # Grad branch
+        self.get_g_nopadding = ImageGradientNoPadding()
+        self.b_fea_conv = ConvGnLelu(in_nc, nf, kernel_size=7, norm=False, activation=False, bias=False)
+        mplex_grad = functools.partial(ConvBasisMultiplexer, nf * 2, nf * 2, switch_reductions,
+                                       switch_processing_layers, self.transformation_counts // 2, use_exp2=True)
+        self.sw_grad = ConfigurableSwitchComputer(transformation_filters, mplex_grad,
+                                                  pre_transform_block=pretransform_fn, transform_block=transform_fn,
+                                                  attention_norm=True,
+                                                  transform_count=self.transformation_counts // 2, init_temp=init_temperature,
+                                                  add_scalable_noise_to_transforms=True)
+        # Upsampling
+        self.grad_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=False)
+        self.grad_hr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False, bias=False)
+        # Conv used to output grad branch shortcut.
+        self.grad_branch_output_conv = ConvGnLelu(nf, 3, kernel_size=1, norm=False, activation=False, bias=False)
+
+        # Conjoin branch.
+        # Note: "_branch_pretrain" is a special tag used to denote parameters that get pretrained before the rest.
+        transform_fn_cat = functools.partial(MultiConvBlock, transformation_filters * 2, int(transformation_filters * 1.5),
+                                             transformation_filters, kernel_size=3, depth=4,
+                                             weight_init_factor=.1)
+        pretransform_fn_cat = functools.partial(ConvGnLelu, transformation_filters * 2, transformation_filters * 2, norm=False, bias=False, weight_init_factor=.1)
+        self._branch_pretrain_sw = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
+                                                              pre_transform_block=pretransform_fn_cat, transform_block=transform_fn_cat,
+                                                              attention_norm=True,
+                                                              transform_count=self.transformation_counts, init_temp=init_temperature,
+                                                              add_scalable_noise_to_transforms=True)
+        self.upsample = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=False, activation=False, bias=True) for _ in range(n_upscale)])
+        self.upsample_grad = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=False, activation=False, bias=True) for _ in range(n_upscale)])
+        self.final_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=True, activation=False)
+        self.final_hr_conv1 = ConvGnLelu(nf, nf, kernel_size=3, norm=True, activation=True, bias=True)
+        self.final_hr_conv2 = ConvGnLelu(nf, 3, kernel_size=3, norm=False, activation=False, bias=False)
+        self.switches = [self.sw1, self.sw2, self.sw_grad, self._branch_pretrain_sw]
+        self.attentions = None
+        self.init_temperature = init_temperature
+        self.final_temperature_step = 10000
+
+    def forward(self, x):
+        x_grad = self.get_g_nopadding(x)
+        x = self.model_fea_conv(x)
+
+        x1, a1 = self.sw1(x, do_checkpointing=True)
+        x2, a2 = self.sw2(x1, do_checkpointing=True)
+        x_fea = self.feature_lr_conv(x2)
+        x_fea = self.feature_hr_conv2(x_fea)
+
+        x_b_fea = self.b_fea_conv(x_grad)
+        x_grad, a3 = self.sw_grad(x_b_fea, att_in=torch.cat([x1, x_b_fea], dim=1), output_attention_weights=True, do_checkpointing=True)
+        x_grad = checkpoint(self.grad_lr_conv, x_grad)
+        x_grad = checkpoint(self.grad_hr_conv, x_grad)
+        x_out_branch = checkpoint(self.upsample_grad, x_grad)
+        x_out_branch = self.grad_branch_output_conv(x_out_branch)
+
+        x__branch_pretrain_cat = torch.cat([x_grad, x_fea], dim=1)
+        x__branch_pretrain_cat, a4 = self._branch_pretrain_sw(x__branch_pretrain_cat, att_in=x_fea, identity=x_fea, output_attention_weights=True)
+        x_out = checkpoint(self.final_lr_conv, x__branch_pretrain_cat)
+        x_out = checkpoint(self.upsample, x_out)
+        x_out = checkpoint(self.final_hr_conv1, x_out)
+        x_out = self.final_hr_conv2(x_out)
+
+        self.attentions = [a1, a2, a3, a4]
+
+        return x_out_branch, x_out, x_grad
+
+    def set_temperature(self, temp):
+        [sw.set_temperature(temp) for sw in self.switches]
+
+    def update_for_step(self, step, experiments_path='.'):
+        if self.attentions:
+            temp = max(1, 1 + self.init_temperature *
+                       (self.final_temperature_step - step) / self.final_temperature_step)
+            self.set_temperature(temp)
+            if step % 200 == 0:
+                output_path = os.path.join(experiments_path, "attention_maps", "a%i")
+                prefix = "attention_map_%i_%%i.png" % (step,)
+                [save_attention_to_image_rgb(output_path % (i,), self.attentions[i], self.transformation_counts, prefix, step) for i in range(len(self.attentions))]
+
+    def get_debug_values(self, step, net):
+        temp = self.switches[0].switch.temperature
+        mean_hists = [compute_attention_specificity(att, 2) for att in self.attentions]
+        means = [i[0] for i in mean_hists]
+        hists = [i[1].clone().detach().cpu().flatten() for i in mean_hists]
+        val = {"switch_temperature": temp}
+        for i in range(len(means)):
+            val["switch_%i_specificity" % (i,)] = means[i]
+            val["switch_%i_histogram" % (i,)] = hists[i]
+        return val
\ No newline at end of file
diff --git a/codes/models/archs/SwitchedResidualGenerator_arch.py b/codes/models/archs/SwitchedResidualGenerator_arch.py
index eb49aa62..7545d2d4 100644
--- a/codes/models/archs/SwitchedResidualGenerator_arch.py
+++ b/codes/models/archs/SwitchedResidualGenerator_arch.py
@@ -79,7 +79,8 @@ def gather_2d(input, index):
 
 class ConfigurableSwitchComputer(nn.Module):
     def __init__(self, base_filters, multiplexer_net, pre_transform_block, transform_block, transform_count, attention_norm,
-                 init_temp=20, add_scalable_noise_to_transforms=False, feed_transforms_into_multiplexer=False):
+                 init_temp=20, add_scalable_noise_to_transforms=False, feed_transforms_into_multiplexer=False, post_switch_conv=True,
+                 anorm_multiplier=16):
         super(ConfigurableSwitchComputer, self).__init__()
 
         tc = transform_count
@@ -95,12 +96,15 @@ class ConfigurableSwitchComputer(nn.Module):
         self.noise_scale = nn.Parameter(torch.full((1,), float(1e-3)))
 
         # And the switch itself, including learned scalars
-        self.switch = BareConvSwitch(initial_temperature=init_temp, attention_norm=AttentionNorm(transform_count, accumulator_size=16 * transform_count) if attention_norm else None)
+        self.switch = BareConvSwitch(initial_temperature=init_temp, attention_norm=AttentionNorm(transform_count, accumulator_size=anorm_multiplier * transform_count) if attention_norm else None)
         self.switch_scale = nn.Parameter(torch.full((1,), float(1)))
-        self.post_switch_conv = ConvBnLelu(base_filters, base_filters, norm=False, bias=True)
-        # The post_switch_conv gets a low scale initially. The network can decide to magnify it (or not)
-        # depending on its needs.
-        self.psc_scale = nn.Parameter(torch.full((1,), float(.1)))
+        if post_switch_conv:
+            self.post_switch_conv = ConvBnLelu(base_filters, base_filters, norm=False, bias=True)
+            # The post_switch_conv gets a low scale initially. The network can decide to magnify it (or not)
+            # depending on its needs.
+            self.psc_scale = nn.Parameter(torch.full((1,), float(.1)))
+        else:
+            self.post_switch_conv = None
         self.update_norm = True
 
     def set_update_attention_norm(self, set_val):
@@ -151,7 +155,8 @@ class ConfigurableSwitchComputer(nn.Module):
         # It is assumed that [xformed] and [m] are collapsed into tensors at this point.
         outputs, attention, att_logits = self.switch(xformed, m, True, self.update_norm, output_attention_logits=True)
         outputs = identity + outputs * self.switch_scale * fixed_scale
-        outputs = outputs + self.post_switch_conv(outputs) * self.psc_scale * fixed_scale
+        if self.post_switch_conv is not None:
+            outputs = outputs + self.post_switch_conv(outputs) * self.psc_scale * fixed_scale
         if output_attention_weights:
             if output_att_logits:
                 return outputs, attention, att_logits
@@ -642,7 +647,8 @@ class TheBigSwitch(SwitchModelBase):
                                                  pre_transform_block=None, transform_block=transform_fn,
                                                  attention_norm=True,
                                                  transform_count=self.transformation_counts, init_temp=init_temperature,
-                                                 add_scalable_noise_to_transforms=False, feed_transforms_into_multiplexer=True)
+                                                 add_scalable_noise_to_transforms=False, feed_transforms_into_multiplexer=True,
+                                                 anorm_multiplier=128)
         self.switches = [self.switch]
 
         self.final_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=True)
@@ -670,6 +676,64 @@ class TheBigSwitch(SwitchModelBase):
         return x_out,
 
 
+class ArtistMultiplexer(nn.Module):
+    def __init__(self, in_nc, nf, multiplexer_channels):
+        super(ArtistMultiplexer, self).__init__()
+
+        self.spine = SpineNet(arch='96', output_level=[3], double_reduce_early=False)
+        self.spine_red_proc = ConvGnSilu(256, nf, kernel_size=1, activation=False, norm=False, bias=False)
+        self.fea_tail = ConvGnSilu(in_nc, nf, kernel_size=7, bias=True, norm=False, activation=False)
+        self.tail_proc = make_res_layer(BasicBlock, nf, nf, 2)
+        self.tail_join = ReferenceJoinBlock(nf)
+
+        self.reduce = ConvGnSilu(nf, nf // 2, kernel_size=1, activation=True, norm=True, bias=False)
+        self.last_process = ConvGnSilu(nf // 2, nf // 2, kernel_size=1, activation=True, norm=False, bias=False)
+        self.to_attention = ConvGnSilu(nf // 2, multiplexer_channels, kernel_size=1, activation=False, norm=False, bias=False)
+
+    def forward(self, x, transformations):
+        s = self.spine(x)[0]
+        tail = self.fea_tail(x)
+        tail = self.tail_proc(tail)
+        q = F.interpolate(s, scale_factor=2, mode='nearest')
+        q = self.spine_red_proc(q)
+        q, _ = self.tail_join(q, tail)
+        q = self.reduce(q)
+        q = F.interpolate(q, scale_factor=2, mode='nearest')
+        return self.to_attention(self.last_process(q))
+
+
+class ArtistGen(SwitchModelBase):
+    def __init__(self, in_nc, nf, xforms=16, upscale=2, init_temperature=10):
+        super(ArtistGen, self).__init__(init_temperature, 10000)
+        self.nf = nf
+        self.transformation_counts = xforms
+
+        multiplx_fn = functools.partial(ArtistMultiplexer, in_nc, nf)
+        transform_fn = functools.partial(MultiConvBlock, in_nc, int(in_nc * 2), in_nc, kernel_size=3, depth=4, weight_init_factor=.1)
+        self.switch = ConfigurableSwitchComputer(nf, multiplx_fn,
+                                                 pre_transform_block=None, transform_block=transform_fn,
+                                                 attention_norm=True,
+                                                 transform_count=self.transformation_counts, init_temp=init_temperature,
+                                                 add_scalable_noise_to_transforms=False, feed_transforms_into_multiplexer=True,
+                                                 anorm_multiplier=128, post_switch_conv=False)
+        self.switches = [self.switch]
+
+    def forward(self, x, save_attentions=True):
+        # The attention_maps debugger outputs the low-resolution input. Save that here.
+        self.lr = x.detach().cpu()
+
+        # If we're not saving attention, we also shouldn't be updating the attention norm. This is because the attention
+        # norm should only be getting updates with new data, not recurrent generator sampling.
+        for sw in self.switches:
+            sw.set_update_attention_norm(save_attentions)
+
+        up = F.interpolate(x, scale_factor=2, mode="bicubic")
+        out, a1, att_logits = self.switch(up, att_in=x, do_checkpointing=True, output_att_logits=True)
+
+        if save_attentions:
+            self.attentions = [a1]
+        return out, att_logits.permute(0,3,1,2)
+
+
 if __name__ == '__main__':
     tbs = TheBigSwitch(3, 64)
     x = torch.randn(4,3,64,64)
diff --git a/codes/models/networks.py b/codes/models/networks.py
index e831d690..fe3e928f 100644
--- a/codes/models/networks.py
+++ b/codes/models/networks.py
@@ -67,6 +67,8 @@ def define_G(opt, net_key='network_G', scale=None):
     elif which_model == 'spsr_net_improved':
         netG = spsr.SPSRNetSimplified(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'],
                                       nb=opt_net['nb'], upscale=opt_net['scale'])
+    elif which_model == "spsr_switched":
+        netG = spsr.SwitchedSpsr(in_nc=3, nf=opt_net['nf'], upscale=opt_net['scale'], init_temperature=opt_net['temperature'])
     elif which_model == "spsr5":
         xforms = opt_net['num_transforms'] if 'num_transforms' in opt_net.keys() else 8
         netG = spsr.Spsr5(in_nc=3, out_nc=3, nf=opt_net['nf'], xforms=xforms, upscale=opt_net['scale'],
@@ -112,6 +114,9 @@ def define_G(opt, net_key='network_G', scale=None):
     elif which_model == 'big_switch':
         netG = SwitchedGen_arch.TheBigSwitch(opt_net['in_nc'], nf=opt_net['nf'], xforms=opt_net['num_transforms'], upscale=opt_net['scale'],
                                              init_temperature=opt_net['temperature'])
+    elif which_model == 'artist':
+        netG = SwitchedGen_arch.ArtistGen(opt_net['in_nc'], nf=opt_net['nf'], xforms=opt_net['num_transforms'], upscale=opt_net['scale'],
+                                          init_temperature=opt_net['temperature'])
     elif which_model == "flownet2":
         from models.flownet2.models import FlowNet2
         ld = torch.load(opt_net['load_path'])
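
Usage sketch (not part of the patch): a minimal illustration of how the two generators registered in define_G above could be constructed directly, assuming the module aliases used in networks.py and an arbitrary nf=64 with a 64x64 test input mirroring the existing __main__ block. The concrete argument values are illustrative assumptions, not settings introduced by this change.

import torch
import models.archs.SPSR_arch as spsr
import models.archs.SwitchedResidualGenerator_arch as SwitchedGen_arch

lr = torch.randn(1, 3, 64, 64)  # hypothetical low-res test input

# 'spsr_switched' maps to SwitchedSpsr; its forward returns
# (gradient-branch image, SR image, gradient-branch features).
switched_spsr = spsr.SwitchedSpsr(in_nc=3, nf=64, upscale=4, init_temperature=10)
grad_img, sr_img, grad_fea = switched_spsr(lr)

# 'artist' maps to ArtistGen, a 2x generator built on a single switch with
# post_switch_conv disabled; its forward returns (SR image, attention logits in NCHW).
artist = SwitchedGen_arch.ArtistGen(in_nc=3, nf=64, xforms=16, upscale=2, init_temperature=10)
sr_img, att_logits = artist(lr, save_attentions=True)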