diff --git a/codes/models/archs/SwitchedResidualGenerator_arch.py b/codes/models/archs/SwitchedResidualGenerator_arch.py index 76feccd2..6aa44df2 100644 --- a/codes/models/archs/SwitchedResidualGenerator_arch.py +++ b/codes/models/archs/SwitchedResidualGenerator_arch.py @@ -100,7 +100,9 @@ class SwitchComputer(nn.Module): class ConfigurableSwitchedResidualGenerator(nn.Module): - def __init__(self, switch_filters, switch_reductions, switch_processing_layers, trans_counts, trans_kernel_sizes, trans_layers, trans_filters_mid, initial_temp=20, final_temperature_step=50000): + def __init__(self, switch_filters, switch_reductions, switch_processing_layers, trans_counts, trans_kernel_sizes, + trans_layers, trans_filters_mid, initial_temp=20, final_temperature_step=50000, heightened_temp_min=1, + heightened_final_step=50000): super(ConfigurableSwitchedResidualGenerator, self).__init__() switches = [] for filters, sw_reduce, sw_proc, trans_count, kernel, layers, mid_filters in zip(switch_filters, switch_reductions, switch_processing_layers, trans_counts, trans_kernel_sizes, trans_layers, trans_filters_mid): @@ -112,6 +114,8 @@ class ConfigurableSwitchedResidualGenerator(nn.Module): self.transformation_counts = trans_counts self.init_temperature = initial_temp self.final_temperature_step = final_temperature_step + self.heightened_temp_min = heightened_temp_min + self.heightened_final_step = heightened_final_step self.attentions = None def forward(self, x): @@ -128,8 +132,14 @@ class ConfigurableSwitchedResidualGenerator(nn.Module): def update_for_step(self, step, experiments_path='.'): if self.attentions: temp = max(1, int(self.init_temperature * (self.final_temperature_step - step) / self.final_temperature_step)) + if temp == 1: + # Enter the linear function regime between self.final_temperature_step and self.heightened_final_step + h_steps_total = self.heightened_final_step - self.final_temperature_step + h_steps_current = min(step - self.final_temperature_step, h_steps_total) + h_gap = 1 - self.heightened_temp_min + temp = 1 - h_gap * h_steps_current / h_steps_total self.set_temperature(temp) - if step % 250 == 0: + if step % 50 == 0: [save_attention_to_image(experiments_path, self.attentions[i], self.transformation_counts[i], step, "a%i" % (i+1,), l_mult=float(self.transformation_counts[i]/4)) for i in range(len(self.switches))] def get_debug_values(self, step): diff --git a/codes/models/networks.py b/codes/models/networks.py index e5cccbcf..abcc98cd 100644 --- a/codes/models/networks.py +++ b/codes/models/networks.py @@ -67,7 +67,8 @@ def define_G(opt, net_key='network_G'): switch_processing_layers=opt_net['switch_processing_layers'], trans_counts=opt_net['trans_counts'], trans_kernel_sizes=opt_net['trans_kernel_sizes'], trans_layers=opt_net['trans_layers'], trans_filters_mid=opt_net['trans_filters_mid'], - initial_temp=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step']) + initial_temp=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step'], + heightened_temp_min=opt_net['heightened_temp_min'], heightened_final_step=opt_net['heightened_final_step']) # image corruption elif which_model == 'HighToLowResNet':