Fix inverse temperature curve logic and add upsample factor
parent 0551139b8d
commit 61364ec7d0
@@ -102,7 +102,7 @@ class SwitchComputer(nn.Module):
 class ConfigurableSwitchedResidualGenerator(nn.Module):
     def __init__(self, switch_filters, switch_reductions, switch_processing_layers, trans_counts, trans_kernel_sizes,
                  trans_layers, trans_filters_mid, initial_temp=20, final_temperature_step=50000, heightened_temp_min=1,
-                 heightened_final_step=50000):
+                 heightened_final_step=50000, upsample_factor=1):
         super(ConfigurableSwitchedResidualGenerator, self).__init__()
         switches = []
         for filters, sw_reduce, sw_proc, trans_count, kernel, layers, mid_filters in zip(switch_filters, switch_reductions, switch_processing_layers, trans_counts, trans_kernel_sizes, trans_layers, trans_filters_mid):
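The constructor gains an upsample_factor argument that defaults to 1, so existing callers keep their current behavior. An illustrative instantiation; the per-switch list values below are placeholders, not taken from this commit:

    gen = ConfigurableSwitchedResidualGenerator(
        switch_filters=[32], switch_reductions=[3], switch_processing_layers=[2],
        trans_counts=[8], trans_kernel_sizes=[3], trans_layers=[3],
        trans_filters_mid=[24],
        initial_temp=20, final_temperature_step=50000,
        heightened_temp_min=0.5, heightened_final_step=100000,
        upsample_factor=4)  # new: 4x nearest-neighbor upsample before repair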
@@ -117,8 +117,14 @@ class ConfigurableSwitchedResidualGenerator(nn.Module):
         self.heightened_temp_min = heightened_temp_min
         self.heightened_final_step = heightened_final_step
         self.attentions = None
+        self.upsample_factor = upsample_factor
 
     def forward(self, x):
+        # This network is entirely a "repair" network and operates on full-resolution images. Upsample first if that
+        # is called for, then repair.
+        if self.upsample_factor > 1:
+            x = F.interpolate(x, scale_factor=self.upsample_factor, mode="nearest")
+
         self.attentions = []
         for i, sw in enumerate(self.switches):
             sw_out, att = sw.forward(x, True)
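The upsample step can be exercised in isolation. A minimal sketch of the upsample-then-repair flow added above; the repair stage is omitted since the switch internals are outside this hunk, and the tensor sizes are illustrative:

    import torch
    import torch.nn.functional as F

    lr = torch.randn(1, 3, 32, 32)      # hypothetical low-res input batch
    upsample_factor = 4                 # hypothetical scale
    # Nearest-neighbor upsample first, then the network "repairs" at full resolution.
    x = F.interpolate(lr, scale_factor=upsample_factor, mode="nearest")
    assert x.shape[-2:] == (128, 128)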
@@ -132,7 +138,7 @@ class ConfigurableSwitchedResidualGenerator(nn.Module):
     def update_for_step(self, step, experiments_path='.'):
         if self.attentions:
             temp = max(1, int(self.init_temperature * (self.final_temperature_step - step) / self.final_temperature_step))
-            if temp == 1:
+            if temp == 1 and self.heightened_final_step and self.heightened_final_step != 1:
                 # Once the temperature passes (1) it enters an inverted curve to match the linear curve from above.
                 # without this, the attention specificity "spikes" incredibly fast in the last few iterations.
                 h_steps_total = self.heightened_final_step - self.final_temperature_step
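The added guard keeps the post-1 "heightened" phase from running when it isn't configured. For reference, the schedule has two phases: temperature decays linearly from initial_temp to 1 over final_temperature_step steps, then descends from 1 toward heightened_temp_min by heightened_final_step along an inverted curve. A minimal standalone sketch; the reciprocal ramp and the clamping details are assumptions, since the body of the second phase is outside this hunk:

    def anneal_temperature(step, init_temp=20, final_temperature_step=50000,
                           heightened_temp_min=0.5, heightened_final_step=100000):
        # Phase 1: linear decay from init_temp down to 1.
        temp = max(1, int(init_temp * (final_temperature_step - step) / final_temperature_step))
        # Phase 2: below 1, follow the reciprocal of a linear ramp so attention
        # specificity keeps growing smoothly instead of spiking at the end.
        if temp == 1 and heightened_final_step and heightened_final_step != 1:
            h_steps_total = heightened_final_step - final_temperature_step
            h_steps_current = min(max(step - final_temperature_step, 0), h_steps_total)
            h_gap = 1 / heightened_temp_min  # ramp target on the inverted axis
            temp = 1 / (1 + (h_gap - 1) * h_steps_current / h_steps_total)
        return temp  # equals 1 at phase start, heightened_temp_min at phase end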
@@ -155,4 +161,4 @@ class ConfigurableSwitchedResidualGenerator(nn.Module):
         for i in range(len(means)):
             val["switch_%i_specificity" % (i,)] = means[i]
             val["switch_%i_histogram" % (i,)] = hists[i]
-        return val
+        return val
@@ -131,4 +131,5 @@ class BaseModel():
             self.optimizers[i].load_state_dict(o)
         for i, s in enumerate(resume_schedulers):
             self.schedulers[i].load_state_dict(s)
-        amp.load_state_dict(resume_state['amp'])
+        if 'amp' in resume_state.keys():
+            amp.load_state_dict(resume_state['amp'])
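This makes resuming tolerant of checkpoints saved before amp (NVIDIA apex mixed precision) state was serialized. A minimal sketch of the pattern, with an illustrative checkpoint path and assuming apex is installed:

    import torch
    from apex import amp  # assumes NVIDIA apex is available

    resume_state = torch.load('training_state.pth')  # illustrative path
    # Older checkpoints predate amp serialization, so restore it only when present.
    if 'amp' in resume_state:
        amp.load_state_dict(resume_state['amp'])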
@@ -68,7 +68,8 @@ def define_G(opt, net_key='network_G'):
             trans_kernel_sizes=opt_net['trans_kernel_sizes'], trans_layers=opt_net['trans_layers'],
             trans_filters_mid=opt_net['trans_filters_mid'],
             initial_temp=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step'],
-            heightened_temp_min=opt_net['heightened_temp_min'], heightened_final_step=opt_net['heightened_final_step'])
+            heightened_temp_min=opt_net['heightened_temp_min'], heightened_final_step=opt_net['heightened_final_step'],
+            upsample_factor=scale)
 
     # image corruption
     elif which_model == 'HighToLowResNet':
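With this change the generator's upsample_factor is wired to the training scale. For orientation, a hypothetical fragment of the options dict that define_G reads for this architecture; the keys mirror the call above, but the values are illustrative, not from the commit:

    # Hypothetical 'network_G' options; values are examples only.
    opt_net = {
        'which_model_G': 'ConfigurableSwitchedResidualGenerator',
        'temperature': 20,                # initial_temp
        'temperature_final_step': 50000,  # end of the linear decay phase
        'heightened_temp_min': 0.5,       # floor of the inverted phase
        'heightened_final_step': 100000,  # end of the inverted phase
    }
    scale = 4  # presumably the training scale; becomes upsample_factor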