Add a double-step to attention temperature
commit 778e7b6931
parent d2d5e097d5
@@ -100,7 +100,9 @@ class SwitchComputer(nn.Module):


 class ConfigurableSwitchedResidualGenerator(nn.Module):
-    def __init__(self, switch_filters, switch_reductions, switch_processing_layers, trans_counts, trans_kernel_sizes, trans_layers, trans_filters_mid, initial_temp=20, final_temperature_step=50000):
+    def __init__(self, switch_filters, switch_reductions, switch_processing_layers, trans_counts, trans_kernel_sizes,
+                 trans_layers, trans_filters_mid, initial_temp=20, final_temperature_step=50000, heightened_temp_min=1,
+                 heightened_final_step=50000):
         super(ConfigurableSwitchedResidualGenerator, self).__init__()
         switches = []
         for filters, sw_reduce, sw_proc, trans_count, kernel, layers, mid_filters in zip(switch_filters, switch_reductions, switch_processing_layers, trans_counts, trans_kernel_sizes, trans_layers, trans_filters_mid):
@@ -112,6 +114,8 @@ class ConfigurableSwitchedResidualGenerator(nn.Module):
         self.transformation_counts = trans_counts
         self.init_temperature = initial_temp
         self.final_temperature_step = final_temperature_step
+        self.heightened_temp_min = heightened_temp_min
+        self.heightened_final_step = heightened_final_step
         self.attentions = None

     def forward(self, x):
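For context on why a floor below 1 is useful: the switches apparently apply a temperature-scaled softmax to their attention logits (that is what the init_temperature plumbing suggests; the exact scaling lives in SwitchComputer, not in this hunk), so pushing the temperature under 1 in the new second phase sharpens the selection rather than smoothing it. A minimal illustration with made-up logits:

import torch
import torch.nn.functional as F

logits = torch.tensor([2.0, 1.0, 0.5, 0.1])
for temp in (20.0, 1.0, 0.5):
    # High temperature -> near-uniform attention weights;
    # temp < 1 -> sharply peaked, close to a hard selection.
    print(temp, F.softmax(logits / temp, dim=0))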
@@ -128,8 +132,14 @@ class ConfigurableSwitchedResidualGenerator(nn.Module):
     def update_for_step(self, step, experiments_path='.'):
         if self.attentions:
             temp = max(1, int(self.init_temperature * (self.final_temperature_step - step) / self.final_temperature_step))
+            if temp == 1:
+                # Enter the linear function regime between self.final_temperature_step and self.heightened_final_step
+                h_steps_total = self.heightened_final_step - self.final_temperature_step
+                h_steps_current = min(step - self.final_temperature_step, h_steps_total)
+                h_gap = 1 - self.heightened_temp_min
+                temp = 1 - h_gap * h_steps_current / h_steps_total
             self.set_temperature(temp)
-            if step % 250 == 0:
+            if step % 50 == 0:
                 [save_attention_to_image(experiments_path, self.attentions[i], self.transformation_counts[i], step, "a%i" % (i+1,), l_mult=float(self.transformation_counts[i]/4)) for i in range(len(self.switches))]

     def get_debug_values(self, step):
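This hunk is the "double-step" of the commit title: temperature first decays from initial_temp to 1 over final_temperature_step steps, then ramps linearly from 1 down to heightened_temp_min between final_temperature_step and heightened_final_step. A standalone sketch of the same arithmetic (the function name and the example values are illustrative, not repo code; note that a config must set heightened_final_step above final_temperature_step, since the defaults are equal and h_steps_total would be zero once temp reaches 1):

def temperature_at(step, initial_temp=20, final_temperature_step=50000,
                   heightened_temp_min=0.5, heightened_final_step=100000):
    # Phase 1: integer decay from initial_temp down to 1.
    temp = max(1, int(initial_temp * (final_temperature_step - step) / final_temperature_step))
    if temp == 1:
        # Phase 2: linear ramp from 1 toward heightened_temp_min,
        # clamped once heightened_final_step is reached.
        h_steps_total = heightened_final_step - final_temperature_step
        h_steps_current = min(step - final_temperature_step, h_steps_total)
        temp = 1 - (1 - heightened_temp_min) * h_steps_current / h_steps_total
    return temp

# temperature_at(0) == 20, temperature_at(75000) == 0.75,
# temperature_at(150000) == 0.5 (clamped at heightened_final_step).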
@@ -67,7 +67,8 @@ def define_G(opt, net_key='network_G'):
             switch_processing_layers=opt_net['switch_processing_layers'], trans_counts=opt_net['trans_counts'],
             trans_kernel_sizes=opt_net['trans_kernel_sizes'], trans_layers=opt_net['trans_layers'],
             trans_filters_mid=opt_net['trans_filters_mid'],
-            initial_temp=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step'])
+            initial_temp=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step'],
+            heightened_temp_min=opt_net['heightened_temp_min'], heightened_final_step=opt_net['heightened_final_step'])

     # image corruption
     elif which_model == 'HighToLowResNet':
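Since define_G now indexes both new keys unconditionally, existing option files for this generator need them added or construction raises a KeyError. A hypothetical excerpt showing only the temperature-related keys of the network_G options (values are examples; the other architecture keys are unchanged):

opt_net = {
    # ... existing ConfigurableSwitchedResidualGenerator keys ...
    'temperature': 20,                # initial_temp
    'temperature_final_step': 50000,  # end of phase 1, where temp reaches 1
    'heightened_temp_min': 0.5,       # phase 2 floor; temp < 1 sharpens attention
    'heightened_final_step': 100000,  # end of phase 2; must exceed temperature_final_step
}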