diff --git a/codes/models/archs/SwitchedResidualGenerator_arch.py b/codes/models/archs/SwitchedResidualGenerator_arch.py
index 560adde6..5377553a 100644
--- a/codes/models/archs/SwitchedResidualGenerator_arch.py
+++ b/codes/models/archs/SwitchedResidualGenerator_arch.py
@@ -132,6 +132,83 @@ class SwitchComputer(nn.Module):
         self.switch.set_attention_temperature(temp)
 
 
+class ConfigurableSwitchComputer(nn.Module):
+    def __init__(self, multiplexer_net, transform_block, transform_count, init_temp=20,
+                 enable_negative_transforms=False, add_scalable_noise_to_transforms=False):
+        super(ConfigurableSwitchComputer, self).__init__()
+        self.enable_negative_transforms = enable_negative_transforms
+
+        tc = transform_count
+        if self.enable_negative_transforms:
+            tc = transform_count * 2
+        self.multiplexer = multiplexer_net(tc)
+
+        self.transforms = nn.ModuleList([transform_block() for _ in range(transform_count)])
+        self.add_noise = add_scalable_noise_to_transforms
+
+        # And the switch itself, including learned scalars
+        self.switch = BareConvSwitch(initial_temperature=init_temp)
+        self.scale = nn.Parameter(torch.ones(1))
+        self.bias = nn.Parameter(torch.zeros(1))
+
+    def forward(self, x, output_attention_weights=False):
+        if self.add_noise:
+            rand_feature = torch.randn_like(x)
+            xformed = [t.forward(x, rand_feature) for t in self.transforms]
+        else:
+            xformed = [t.forward(x) for t in self.transforms]
+        if self.enable_negative_transforms:
+            xformed.extend([-t for t in xformed])
+
+        m = self.multiplexer(x)
+        # Interpolate the multiplexer across the entire shape of the image.
+        m = F.interpolate(m, size=x.shape[2:], mode='nearest')
+
+        outputs, attention = self.switch(xformed, m, True)
+        outputs = outputs * self.scale + self.bias
+        if output_attention_weights:
+            return outputs, attention
+        else:
+            return outputs
+
+    def set_temperature(self, temp):
+        self.switch.set_attention_temperature(temp)
+
+
+class ResidualBasisMultiplexerBase(nn.Module):
+    def __init__(self, input_channels, base_filters, growth, reductions, processing_depth):
+        super(ResidualBasisMultiplexerBase, self).__init__()
+        self.filter_conv = ConvBnLelu(input_channels, base_filters)
+        self.reduction_blocks = nn.Sequential(OrderedDict([('block%i:' % (i,), HalvingProcessingBlock(base_filters * 2 ** i)) for i in range(reductions)]))
+        reduction_filters = base_filters * 2 ** reductions
+        self.processing_blocks, self.output_filter_count = create_sequential_growing_processing_block(reduction_filters, growth, processing_depth)
+
+    def forward(self, x):
+        x = self.filter_conv(x)
+        x = self.reduction_blocks(x)
+        x = self.processing_blocks(x)
+        return x
+
+
+class ResidualBasisMultiplexerLeaf(nn.Module):
+    def __init__(self, base, filters, multiplexer_channels, use_bn=False):
+        super(ResidualBasisMultiplexerLeaf, self).__init__()
+        assert(filters > multiplexer_channels)
+        gap = filters - multiplexer_channels
+        assert(gap % 4 == 0)
+        self.base = base
+        self.cbl1 = ConvBnLelu(filters, filters - (gap // 4), bn=use_bn)
+        self.cbl2 = ConvBnLelu(filters - (gap // 4), filters - (gap // 2), bn=use_bn)
+        self.cbl3 = ConvBnLelu(filters - (gap // 2), multiplexer_channels)
+
+    def forward(self, x):
+        x = self.base(x)
+        x = self.cbl1(x)
+        x = self.cbl2(x)
+        x = self.cbl3(x)
+        return x
+
+
 class ConfigurableSwitchedResidualGenerator(nn.Module):
     def __init__(self, switch_filters, switch_growths, switch_reductions, switch_processing_layers, trans_counts, trans_kernel_sizes,
                  trans_layers, trans_filters_mid, initial_temp=20, final_temperature_step=50000, heightened_temp_min=1,
@@ -196,3 +273,71 @@ class ConfigurableSwitchedResidualGenerator(nn.Module):
             val["switch_%i_specificity" % (i,)] = means[i]
             val["switch_%i_histogram" % (i,)] = hists[i]
         return val
+
+
+class ConfigurableSwitchedResidualGenerator2(nn.Module):
+    def __init__(self, switch_filters, switch_growths, switch_reductions, switch_processing_layers, trans_counts, trans_kernel_sizes,
+                 trans_layers, trans_filters_mid, initial_temp=20, final_temperature_step=50000, heightened_temp_min=1,
+                 heightened_final_step=50000, upsample_factor=1, enable_negative_transforms=False,
+                 add_scalable_noise_to_transforms=False):
+        super(ConfigurableSwitchedResidualGenerator2, self).__init__()
+        switches = []
+        multiplexer_base = ResidualBasisMultiplexerBase(3, switch_filters[0], switch_growths[0], switch_reductions[0], switch_processing_layers[0])
+        for trans_count, kernel, layers, mid_filters in zip(trans_counts, trans_kernel_sizes, trans_layers, trans_filters_mid):
+            leaf_fn = functools.partial(ResidualBasisMultiplexerLeaf, multiplexer_base, multiplexer_base.output_filter_count)
+            switches.append(ConfigurableSwitchComputer(leaf_fn, functools.partial(ResidualBranch, 3, mid_filters, 3, kernel_size=kernel, depth=layers), trans_count, initial_temp, enable_negative_transforms=enable_negative_transforms, add_scalable_noise_to_transforms=add_scalable_noise_to_transforms))
+        initialize_weights(switches, 1)
+        # Initialize the transforms with a lesser weight, since they are repeatedly added on to the resultant image.
+        initialize_weights([s.transforms for s in switches], .2 / len(switches))
+        self.switches = nn.ModuleList(switches)
+        self.transformation_counts = trans_counts
+        self.init_temperature = initial_temp
+        self.final_temperature_step = final_temperature_step
+        self.heightened_temp_min = heightened_temp_min
+        self.heightened_final_step = heightened_final_step
+        self.attentions = None
+        self.upsample_factor = upsample_factor
+
+    def forward(self, x):
+        # This network is entirely a "repair" network and operates on full-resolution images. Upsample first if that
+        # is called for, then repair.
+        if self.upsample_factor > 1:
+            x = F.interpolate(x, scale_factor=self.upsample_factor, mode="nearest")
+
+        self.attentions = []
+        for i, sw in enumerate(self.switches):
+            sw_out, att = sw.forward(x, True)
+            x = x + sw_out
+            self.attentions.append(att)
+        return x,
+
+    def set_temperature(self, temp):
+        [sw.set_temperature(temp) for sw in self.switches]
+
+    def update_for_step(self, step, experiments_path='.'):
+        if self.attentions:
+            temp = max(1, int(self.init_temperature * (self.final_temperature_step - step) / self.final_temperature_step))
+            if temp == 1 and self.heightened_final_step and self.heightened_final_step != 1:
+                # Once the temperature passes 1, it enters an inverted curve to match the linear curve from above.
+                # Without this, the attention specificity "spikes" incredibly fast in the last few iterations.
+                h_steps_total = self.heightened_final_step - self.final_temperature_step
+                h_steps_current = min(step - self.final_temperature_step, h_steps_total)
+                # The "gap" represents the span that needs to be traveled as a linear function.
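+                # With f = h_steps_current / h_steps_total, the temperature below follows 1 / (h_gap * f);
+                # e.g. for heightened_temp_min=.5, h_gap=2, so temp crosses 1 at f=.5 and lands exactly
+                # on heightened_temp_min at f=1.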
+                h_gap = 1 / self.heightened_temp_min
+                temp = h_gap * h_steps_current / h_steps_total
+                # Invert the temperature to represent reality on this side of the curve.
+                temp = 1 / temp
+            self.set_temperature(temp)
+            if step % 50 == 0:
+                [save_attention_to_image(experiments_path, self.attentions[i], self.transformation_counts[i], step, "a%i" % (i+1,)) for i in range(len(self.switches))]
+
+    def get_debug_values(self, step):
+        temp = self.switches[0].switch.temperature
+        mean_hists = [compute_attention_specificity(att, 2) for att in self.attentions]
+        means = [i[0] for i in mean_hists]
+        hists = [i[1].clone().detach().cpu().flatten() for i in mean_hists]
+        val = {"switch_temperature": temp}
+        for i in range(len(means)):
+            val["switch_%i_specificity" % (i,)] = means[i]
+            val["switch_%i_histogram" % (i,)] = hists[i]
+        return val
\ No newline at end of file
diff --git a/codes/models/networks.py b/codes/models/networks.py
index 541128eb..2614565f 100644
--- a/codes/models/networks.py
+++ b/codes/models/networks.py
@@ -71,6 +71,15 @@ def define_G(opt, net_key='network_G'):
                                                                       initial_temp=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step'],
                                                                       heightened_temp_min=opt_net['heightened_temp_min'], heightened_final_step=opt_net['heightened_final_step'],
                                                                       upsample_factor=scale, add_scalable_noise_to_transforms=opt_net['add_noise'])
+    elif which_model == "ConfigurableSwitchedResidualGenerator2":
+        netG = SwitchedGen_arch.ConfigurableSwitchedResidualGenerator2(switch_filters=opt_net['switch_filters'], switch_growths=opt_net['switch_growths'],
+                                                                       switch_reductions=opt_net['switch_reductions'],
+                                                                       switch_processing_layers=opt_net['switch_processing_layers'], trans_counts=opt_net['trans_counts'],
+                                                                       trans_kernel_sizes=opt_net['trans_kernel_sizes'], trans_layers=opt_net['trans_layers'],
+                                                                       trans_filters_mid=opt_net['trans_filters_mid'],
+                                                                       initial_temp=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step'],
+                                                                       heightened_temp_min=opt_net['heightened_temp_min'], heightened_final_step=opt_net['heightened_final_step'],
+                                                                       upsample_factor=scale, add_scalable_noise_to_transforms=opt_net['add_noise'])
     # image corruption
     elif which_model == 'HighToLowResNet':
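
Usage sketch: a minimal illustration of the opt_net block the new define_G branch
reads. The key names come from the diff above; every value below is a hypothetical
placeholder, not a recommended configuration. The switch_* options are indexed with
[0] (only the first entry feeds the shared ResidualBasisMultiplexerBase), while the
trans_* lists are zipped together, so their common length sets the number of switches.

    opt_net = {
        'switch_filters': [32],
        'switch_growths': [16],
        'switch_reductions': [3],
        'switch_processing_layers': [2],
        'trans_counts': [8, 8],            # two switches, 8 transforms each
        'trans_kernel_sizes': [3, 3],
        'trans_layers': [4, 4],
        'trans_filters_mid': [24, 24],
        'temperature': 10,                 # initial softmax temperature
        'temperature_final_step': 50000,   # end of the linear decay to 1
        'heightened_temp_min': .5,         # floor of the post-1 inverted curve
        'heightened_final_step': 100000,   # step at which that floor is reached
        'add_noise': True,
    }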