Add a double-step to attention temperature

James Betker 2020-06-18 11:29:31 -06:00
parent d2d5e097d5
commit 778e7b6931
2 changed files with 14 additions and 3 deletions


@@ -100,7 +100,9 @@ class SwitchComputer(nn.Module):
 class ConfigurableSwitchedResidualGenerator(nn.Module):
-    def __init__(self, switch_filters, switch_reductions, switch_processing_layers, trans_counts, trans_kernel_sizes, trans_layers, trans_filters_mid, initial_temp=20, final_temperature_step=50000):
+    def __init__(self, switch_filters, switch_reductions, switch_processing_layers, trans_counts, trans_kernel_sizes,
+                 trans_layers, trans_filters_mid, initial_temp=20, final_temperature_step=50000, heightened_temp_min=1,
+                 heightened_final_step=50000):
         super(ConfigurableSwitchedResidualGenerator, self).__init__()
         switches = []
         for filters, sw_reduce, sw_proc, trans_count, kernel, layers, mid_filters in zip(switch_filters, switch_reductions, switch_processing_layers, trans_counts, trans_kernel_sizes, trans_layers, trans_filters_mid):
@@ -112,6 +114,8 @@ class ConfigurableSwitchedResidualGenerator(nn.Module):
         self.transformation_counts = trans_counts
         self.init_temperature = initial_temp
         self.final_temperature_step = final_temperature_step
+        self.heightened_temp_min = heightened_temp_min
+        self.heightened_final_step = heightened_final_step
         self.attentions = None

     def forward(self, x):
@@ -128,8 +132,14 @@ class ConfigurableSwitchedResidualGenerator(nn.Module):
     def update_for_step(self, step, experiments_path='.'):
         if self.attentions:
             temp = max(1, int(self.init_temperature * (self.final_temperature_step - step) / self.final_temperature_step))
+            if temp == 1:
+                # Enter the linear function regime between self.final_temperature_step and self.heightened_final_step
+                h_steps_total = self.heightened_final_step - self.final_temperature_step
+                h_steps_current = min(step - self.final_temperature_step, h_steps_total)
+                h_gap = 1 - self.heightened_temp_min
+                temp = 1 - h_gap * h_steps_current / h_steps_total
             self.set_temperature(temp)
-            if step % 250 == 0:
+            if step % 50 == 0:
                 [save_attention_to_image(experiments_path, self.attentions[i], self.transformation_counts[i], step, "a%i" % (i+1,), l_mult=float(self.transformation_counts[i]/4)) for i in range(len(self.switches))]

     def get_debug_values(self, step):
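
The schedule implemented by the new update_for_step branch has two phases: a coarse integer decay from initial_temp down to 1 over the first final_temperature_step steps, then a linear interpolation from 1 toward heightened_temp_min that completes at heightened_final_step (the "double-step" of the commit title). The standalone helper below is only an illustrative sketch of that schedule, not code from this commit; the extra guard is an addition for the default case heightened_final_step == final_temperature_step, where the second phase is effectively a no-op.

def temperature_for_step(step, initial_temp=20, final_temperature_step=50000,
                         heightened_temp_min=1, heightened_final_step=50000):
    """Illustrative sketch of the two-phase attention temperature schedule."""
    # Phase 1: integer temperature falling linearly from initial_temp, floored at 1.
    temp = max(1, int(initial_temp * (final_temperature_step - step) / final_temperature_step))
    # Phase 2 ("double-step"): once the floor is hit, interpolate linearly from 1
    # toward heightened_temp_min between final_temperature_step and heightened_final_step.
    if temp == 1 and heightened_final_step > final_temperature_step:  # guard added for this sketch
        h_steps_total = heightened_final_step - final_temperature_step
        h_steps_current = min(step - final_temperature_step, h_steps_total)
        h_gap = 1 - heightened_temp_min
        temp = 1 - h_gap * h_steps_current / h_steps_total
    return temp

At step == final_temperature_step the second phase evaluates to exactly 1, and from heightened_final_step onward it stays at heightened_temp_min.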


@@ -67,7 +67,8 @@ def define_G(opt, net_key='network_G'):
             switch_processing_layers=opt_net['switch_processing_layers'], trans_counts=opt_net['trans_counts'],
             trans_kernel_sizes=opt_net['trans_kernel_sizes'], trans_layers=opt_net['trans_layers'],
             trans_filters_mid=opt_net['trans_filters_mid'],
-            initial_temp=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step'])
+            initial_temp=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step'],
+            heightened_temp_min=opt_net['heightened_temp_min'], heightened_final_step=opt_net['heightened_final_step'])

     # image corruption
     elif which_model == 'HighToLowResNet':
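
define_G now reads two additional keys from the generator options. Below is a hypothetical slice of opt_net with only the temperature-related settings: the key names come from the diff, while the values are placeholders and the plain dict is just an assumption about how the options arrive (the real values live in the experiment's config file).

# Hypothetical temperature-related portion of opt_net for the
# ConfigurableSwitchedResidualGenerator branch above. Key names match the
# diff; the values are illustrative placeholders only.
opt_net = {
    'temperature': 10,                 # initial_temp: starting attention temperature
    'temperature_final_step': 50000,   # step at which phase 1 bottoms out at 1
    'heightened_temp_min': 0.5,        # temperature targeted by the second phase
    'heightened_final_step': 100000,   # step at which the second phase completes
}

Unless the option parser fills in defaults for missing keys, existing configs for this generator will need the two new entries once this change lands.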