Fix inverse temperature curve logic and add upsample factor
This commit is contained in:
parent
0551139b8d
commit
61364ec7d0
|
@ -102,7 +102,7 @@ class SwitchComputer(nn.Module):
|
|||
class ConfigurableSwitchedResidualGenerator(nn.Module):
|
||||
def __init__(self, switch_filters, switch_reductions, switch_processing_layers, trans_counts, trans_kernel_sizes,
|
||||
trans_layers, trans_filters_mid, initial_temp=20, final_temperature_step=50000, heightened_temp_min=1,
|
||||
heightened_final_step=50000):
|
||||
heightened_final_step=50000, upsample_factor=1):
|
||||
super(ConfigurableSwitchedResidualGenerator, self).__init__()
|
||||
switches = []
|
||||
for filters, sw_reduce, sw_proc, trans_count, kernel, layers, mid_filters in zip(switch_filters, switch_reductions, switch_processing_layers, trans_counts, trans_kernel_sizes, trans_layers, trans_filters_mid):
|
||||
|
@ -117,8 +117,14 @@ class ConfigurableSwitchedResidualGenerator(nn.Module):
|
|||
self.heightened_temp_min = heightened_temp_min
|
||||
self.heightened_final_step = heightened_final_step
|
||||
self.attentions = None
|
||||
self.upsample_factor = upsample_factor
|
||||
|
||||
def forward(self, x):
|
||||
# This network is entirely a "repair" network and operates on full-resolution images. Upsample first if that
|
||||
# is called for, then repair.
|
||||
if self.upsample_factor > 1:
|
||||
x = F.interpolate(x, scale_factor=self.upsample_factor, mode="nearest")
|
||||
|
||||
self.attentions = []
|
||||
for i, sw in enumerate(self.switches):
|
||||
sw_out, att = sw.forward(x, True)
|
||||
|
@ -132,7 +138,7 @@ class ConfigurableSwitchedResidualGenerator(nn.Module):
|
|||
def update_for_step(self, step, experiments_path='.'):
|
||||
if self.attentions:
|
||||
temp = max(1, int(self.init_temperature * (self.final_temperature_step - step) / self.final_temperature_step))
|
||||
if temp == 1:
|
||||
if temp == 1 and self.heightened_final_step and self.heightened_final_step != 1:
|
||||
# Once the temperature passes (1) it enters an inverted curve to match the linear curve from above.
|
||||
# without this, the attention specificity "spikes" incredibly fast in the last few iterations.
|
||||
h_steps_total = self.heightened_final_step - self.final_temperature_step
|
||||
|
@ -155,4 +161,4 @@ class ConfigurableSwitchedResidualGenerator(nn.Module):
|
|||
for i in range(len(means)):
|
||||
val["switch_%i_specificity" % (i,)] = means[i]
|
||||
val["switch_%i_histogram" % (i,)] = hists[i]
|
||||
return val
|
||||
return val
|
||||
|
|
|
@ -131,4 +131,5 @@ class BaseModel():
|
|||
self.optimizers[i].load_state_dict(o)
|
||||
for i, s in enumerate(resume_schedulers):
|
||||
self.schedulers[i].load_state_dict(s)
|
||||
amp.load_state_dict(resume_state['amp'])
|
||||
if 'amp' in resume_state.keys():
|
||||
amp.load_state_dict(resume_state['amp'])
|
||||
|
|
|
@ -68,7 +68,8 @@ def define_G(opt, net_key='network_G'):
|
|||
trans_kernel_sizes=opt_net['trans_kernel_sizes'], trans_layers=opt_net['trans_layers'],
|
||||
trans_filters_mid=opt_net['trans_filters_mid'],
|
||||
initial_temp=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step'],
|
||||
heightened_temp_min=opt_net['heightened_temp_min'], heightened_final_step=opt_net['heightened_final_step'])
|
||||
heightened_temp_min=opt_net['heightened_temp_min'], heightened_final_step=opt_net['heightened_final_step'],
|
||||
upsample_factor=scale)
|
||||
|
||||
# image corruption
|
||||
elif which_model == 'HighToLowResNet':
|
||||
|
|
Loading…
Reference in New Issue
Block a user