Fix inverse temperature curve logic and add upsample factor

This commit is contained in:
James Betker 2020-06-19 09:18:30 -06:00
parent 0551139b8d
commit 61364ec7d0
3 changed files with 13 additions and 5 deletions

View File

@ -102,7 +102,7 @@ class SwitchComputer(nn.Module):
class ConfigurableSwitchedResidualGenerator(nn.Module):
def __init__(self, switch_filters, switch_reductions, switch_processing_layers, trans_counts, trans_kernel_sizes,
trans_layers, trans_filters_mid, initial_temp=20, final_temperature_step=50000, heightened_temp_min=1,
heightened_final_step=50000):
heightened_final_step=50000, upsample_factor=1):
super(ConfigurableSwitchedResidualGenerator, self).__init__()
switches = []
for filters, sw_reduce, sw_proc, trans_count, kernel, layers, mid_filters in zip(switch_filters, switch_reductions, switch_processing_layers, trans_counts, trans_kernel_sizes, trans_layers, trans_filters_mid):
@ -117,8 +117,14 @@ class ConfigurableSwitchedResidualGenerator(nn.Module):
self.heightened_temp_min = heightened_temp_min
self.heightened_final_step = heightened_final_step
self.attentions = None
self.upsample_factor = upsample_factor
def forward(self, x):
# This network is entirely a "repair" network and operates on full-resolution images. Upsample first if that
# is called for, then repair.
if self.upsample_factor > 1:
x = F.interpolate(x, scale_factor=self.upsample_factor, mode="nearest")
self.attentions = []
for i, sw in enumerate(self.switches):
sw_out, att = sw.forward(x, True)
@ -132,7 +138,7 @@ class ConfigurableSwitchedResidualGenerator(nn.Module):
def update_for_step(self, step, experiments_path='.'):
if self.attentions:
temp = max(1, int(self.init_temperature * (self.final_temperature_step - step) / self.final_temperature_step))
if temp == 1:
if temp == 1 and self.heightened_final_step and self.heightened_final_step != 1:
# Once the temperature reaches 1, it enters an inverted curve to match the linear curve from above.
# Without this, the attention specificity "spikes" incredibly fast in the last few iterations.
h_steps_total = self.heightened_final_step - self.final_temperature_step
@ -155,4 +161,4 @@ class ConfigurableSwitchedResidualGenerator(nn.Module):
for i in range(len(means)):
val["switch_%i_specificity" % (i,)] = means[i]
val["switch_%i_histogram" % (i,)] = hists[i]
return val
return val

View File

@ -131,4 +131,5 @@ class BaseModel():
self.optimizers[i].load_state_dict(o)
for i, s in enumerate(resume_schedulers):
self.schedulers[i].load_state_dict(s)
amp.load_state_dict(resume_state['amp'])
if 'amp' in resume_state.keys():
amp.load_state_dict(resume_state['amp'])

View File

@ -68,7 +68,8 @@ def define_G(opt, net_key='network_G'):
trans_kernel_sizes=opt_net['trans_kernel_sizes'], trans_layers=opt_net['trans_layers'],
trans_filters_mid=opt_net['trans_filters_mid'],
initial_temp=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step'],
heightened_temp_min=opt_net['heightened_temp_min'], heightened_final_step=opt_net['heightened_final_step'])
heightened_temp_min=opt_net['heightened_temp_min'], heightened_final_step=opt_net['heightened_final_step'],
upsample_factor=scale)
# image corruption
elif which_model == 'HighToLowResNet':