From 0c4c388e15d242a0f2e1a8817e767a36e3132eb1 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Thu, 16 Jul 2020 10:09:24 -0600
Subject: [PATCH] Remove dualoutputsrg

Good idea, didn't pan out.
---
 .../archs/SwitchedResidualGenerator_arch.py   | 92 -------------------
 1 file changed, 92 deletions(-)

diff --git a/codes/models/archs/SwitchedResidualGenerator_arch.py b/codes/models/archs/SwitchedResidualGenerator_arch.py
index 8c0bb87b..f545e3df 100644
--- a/codes/models/archs/SwitchedResidualGenerator_arch.py
+++ b/codes/models/archs/SwitchedResidualGenerator_arch.py
@@ -335,95 +335,3 @@ class ConfigurableSwitchedResidualGenerator3(nn.Module):
             val["switch_%i_histogram" % (i,)] = hists[i]
         return val
 
-
-class DualOutputSRG(nn.Module):
-    def __init__(self, switch_depth, switch_filters, switch_reductions, switch_processing_layers, trans_counts, trans_kernel_sizes,
-                 trans_layers, transformation_filters, initial_temp=20, final_temperature_step=50000, heightened_temp_min=1,
-                 heightened_final_step=50000, upsample_factor=1,
-                 add_scalable_noise_to_transforms=False):
-        super(DualOutputSRG, self).__init__()
-        switches = []
-        self.initial_conv = ConvBnLelu(3, transformation_filters, norm=False, activation=False, bias=True)
-
-        self.fea_upconv1 = ConvBnLelu(transformation_filters, transformation_filters, norm=False, bias=True)
-        self.fea_upconv2 = ConvBnLelu(transformation_filters, transformation_filters, norm=False, bias=True)
-        self.fea_hr_conv = ConvBnLelu(transformation_filters, transformation_filters, norm=False, bias=True)
-        self.fea_final_conv = ConvBnLelu(transformation_filters, 3, norm=False, activation=False, bias=True)
-
-        self.upconv1 = ConvBnLelu(transformation_filters, transformation_filters, norm=False, bias=True)
-        self.upconv2 = ConvBnLelu(transformation_filters, transformation_filters, norm=False, bias=True)
-        self.hr_conv = ConvBnLelu(transformation_filters, transformation_filters, norm=False, bias=True)
-        self.final_conv = ConvBnLelu(transformation_filters, 3, norm=False, activation=False, bias=True)
-
-        for _ in range(switch_depth):
-            multiplx_fn = functools.partial(ConvBasisMultiplexer, transformation_filters, switch_filters, switch_reductions, switch_processing_layers, trans_counts)
-            pretransform_fn = functools.partial(ConvBnLelu, transformation_filters, transformation_filters, norm=False, bias=False, weight_init_factor=.1)
-            transform_fn = functools.partial(MultiConvBlock, transformation_filters, int(transformation_filters * 1.5), transformation_filters, kernel_size=trans_kernel_sizes, depth=trans_layers, weight_init_factor=.1)
-            switches.append(ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
-                                                       pre_transform_block=pretransform_fn, transform_block=transform_fn,
-                                                       transform_count=trans_counts, init_temp=initial_temp,
-                                                       add_scalable_noise_to_transforms=add_scalable_noise_to_transforms))
-
-        self.switches = nn.ModuleList(switches)
-        self.transformation_counts = trans_counts
-        self.init_temperature = initial_temp
-        self.final_temperature_step = final_temperature_step
-        self.heightened_temp_min = heightened_temp_min
-        self.heightened_final_step = heightened_final_step
-        self.attentions = None
-        self.upsample_factor = upsample_factor
-        assert self.upsample_factor == 2 or self.upsample_factor == 4
-
-    def forward(self, x):
-        x = self.initial_conv(x)
-
-        self.attentions = []
-        for i, sw in enumerate(self.switches):
-            x, att = sw.forward(x, True)
-            self.attentions.append(att)
-
-            if i == len(self.switches)-2:
-                fea = self.fea_upconv1(F.interpolate(x, scale_factor=2, mode="nearest"))
-                if self.upsample_factor > 2:
-                    fea = F.interpolate(fea, scale_factor=2, mode="nearest")
-                fea = self.fea_upconv2(fea)
-                fea = self.fea_final_conv(self.hr_conv(fea))
-
-        x = self.upconv1(F.interpolate(x, scale_factor=2, mode="nearest"))
-        if self.upsample_factor > 2:
-            x = F.interpolate(x, scale_factor=2, mode="nearest")
-        x = self.upconv2(x)
-        return fea, self.final_conv(self.hr_conv(x))
-
-    def set_temperature(self, temp):
-        [sw.set_temperature(temp) for sw in self.switches]
-
-    def update_for_step(self, step, experiments_path='.'):
-        if self.attentions:
-            temp = max(1,
-                       1 + self.init_temperature * (self.final_temperature_step - step) / self.final_temperature_step)
-            if temp == 1 and self.heightened_final_step and step > self.final_temperature_step and \
-                    self.heightened_final_step != 1:
-                # Once the temperature passes (1) it enters an inverted curve to match the linear curve from above.
-                # without this, the attention specificity "spikes" incredibly fast in the last few iterations.
-                h_steps_total = self.heightened_final_step - self.final_temperature_step
-                h_steps_current = min(step - self.final_temperature_step, h_steps_total)
-                # The "gap" will represent the steps that need to be traveled as a linear function.
-                h_gap = 1 / self.heightened_temp_min
-                temp = h_gap * h_steps_current / h_steps_total
-                # Invert temperature to represent reality on this side of the curve
-                temp = 1 / temp
-            self.set_temperature(temp)
-            if step % 50 == 0:
-                save_attention_to_image(experiments_path, self.attentions[0], self.transformation_counts, step, "a%i" % (1,), l_mult=10)
-
-    def get_debug_values(self, step):
-        temp = self.switches[0].switch.temperature
-        mean_hists = [compute_attention_specificity(att, 2) for att in self.attentions]
-        means = [i[0] for i in mean_hists]
-        hists = [i[1].clone().detach().cpu().flatten() for i in mean_hists]
-        val = {"switch_temperature": temp}
-        for i in range(len(means)):
-            val["switch_%i_specificity" % (i,)] = means[i]
-            val["switch_%i_histogram" % (i,)] = hists[i]
-        return val
\ No newline at end of file