diff --git a/codes/models/archs/SwitchedResidualGenerator_arch.py b/codes/models/archs/SwitchedResidualGenerator_arch.py index ec28cd96..6fb49c13 100644 --- a/codes/models/archs/SwitchedResidualGenerator_arch.py +++ b/codes/models/archs/SwitchedResidualGenerator_arch.py @@ -614,7 +614,95 @@ class SwitchModelBase(nn.Module): return val -if __name__ == '__main__': - tbs = TheBigSwitch(3, 64) - x = torch.randn(4,3,64,64) - b = tbs(x) \ No newline at end of file +class ConfigurableSwitchedResidualGenerator2(nn.Module): + def __init__(self, switch_depth, switch_filters, switch_reductions, switch_processing_layers, trans_counts, trans_kernel_sizes, + trans_layers, transformation_filters, attention_norm, initial_temp=20, final_temperature_step=50000, heightened_temp_min=1, + heightened_final_step=50000, upsample_factor=1, + add_scalable_noise_to_transforms=False): + super(ConfigurableSwitchedResidualGenerator2, self).__init__() + switches = [] + self.initial_conv = ConvBnLelu(3, transformation_filters, kernel_size=7, norm=False, activation=False, bias=True) + self.upconv1 = ConvBnLelu(transformation_filters, transformation_filters, norm=False, bias=True) + self.upconv2 = ConvBnLelu(transformation_filters, transformation_filters, norm=False, bias=True) + self.hr_conv = ConvBnLelu(transformation_filters, transformation_filters, norm=False, bias=True) + self.final_conv = ConvBnLelu(transformation_filters, 3, norm=False, activation=False, bias=True) + for _ in range(switch_depth): + multiplx_fn = functools.partial(ConvBasisMultiplexer, transformation_filters, switch_filters, switch_reductions, switch_processing_layers, trans_counts) + pretransform_fn = functools.partial(ConvBnLelu, transformation_filters, transformation_filters, norm=False, bias=False, weight_init_factor=.1) + transform_fn = functools.partial(MultiConvBlock, transformation_filters, int(transformation_filters * 1.5), transformation_filters, kernel_size=trans_kernel_sizes, depth=trans_layers, weight_init_factor=.1) + switches.append(ConfigurableSwitchComputer(transformation_filters, multiplx_fn, + pre_transform_block=pretransform_fn, transform_block=transform_fn, + attention_norm=attention_norm, + transform_count=trans_counts, init_temp=initial_temp, + add_scalable_noise_to_transforms=add_scalable_noise_to_transforms)) + + self.switches = nn.ModuleList(switches) + self.transformation_counts = trans_counts + self.init_temperature = initial_temp + self.final_temperature_step = final_temperature_step + self.heightened_temp_min = heightened_temp_min + self.heightened_final_step = heightened_final_step + self.attentions = None + self.upsample_factor = upsample_factor + self.lr = None + assert self.upsample_factor == 2 or self.upsample_factor == 4 + + def forward(self, x): + self.lr = x.detach().cpu() + + # This is a common bug when evaluating SRG2 generators. It needs to be configured properly in eval mode. Just fail. + if not self.train: + assert self.switches[0].switch.temperature == 1 + + x = self.initial_conv(x) + + self.attentions = [] + for i, sw in enumerate(self.switches): + x, att = checkpoint(sw, x) + self.attentions.append(att) + + x = self.upconv1(F.interpolate(x, scale_factor=2, mode="nearest")) + if self.upsample_factor > 2: + x = F.interpolate(x, scale_factor=2, mode="nearest") + x = self.upconv2(x) + x = self.final_conv(self.hr_conv(x)) + return x + + def set_temperature(self, temp): + [sw.set_temperature(temp) for sw in self.switches] + + def update_for_step(self, step, experiments_path='.'): + if self.attentions: + temp = max(1, + 1 + self.init_temperature * (self.final_temperature_step - step) / self.final_temperature_step) + if temp == 1 and self.heightened_final_step and step > self.final_temperature_step and \ + self.heightened_final_step != 1: + # Once the temperature passes (1) it enters an inverted curve to match the linear curve from above. + # without this, the attention specificity "spikes" incredibly fast in the last few iterations. + h_steps_total = self.heightened_final_step - self.final_temperature_step + h_steps_current = min(step - self.final_temperature_step, h_steps_total) + # The "gap" will represent the steps that need to be traveled as a linear function. + h_gap = 1 / self.heightened_temp_min + temp = h_gap * h_steps_current / h_steps_total + # Invert temperature to represent reality on this side of the curve + temp = 1 / temp + self.set_temperature(temp) + if step % 100 == 0: + output_path = os.path.join(experiments_path, "attention_maps") + prefix = "amap_%i_a%i_%%i.png" + [save_attention_to_image_rgb(output_path, self.attentions[i], self.attentions[i].shape[3], prefix % (step, i), step, + output_mag=False) for i in range(len(self.attentions))] + if self.lr is not None: + torchvision.utils.save_image(self.lr[:, :3], os.path.join(experiments_path, "attention_maps", + "amap_%i_base_image.png" % (step,))) + + def get_debug_values(self, step, net_name): + temp = self.switches[0].switch.temperature + mean_hists = [compute_attention_specificity(att, 2) for att in self.attentions] + means = [i[0] for i in mean_hists] + hists = [i[1].clone().detach().cpu().flatten() for i in mean_hists] + val = {"switch_temperature": temp} + for i in range(len(means)): + val["switch_%i_specificity" % (i,)] = means[i] + val["switch_%i_histogram" % (i,)] = hists[i] + return val \ No newline at end of file