Allow legacy state_dicts in srg2
This commit is contained in:
parent
1b1431133b
commit
853468ef82
|
@ -322,6 +322,17 @@ class ConfigurableSwitchedResidualGenerator3(nn.Module):
|
||||||
return val
|
return val
|
||||||
|
|
||||||
|
|
||||||
|
def load_state_dict(self, state_dict, strict=True):
    """Load ``state_dict``, back-filling accumulator entries that legacy checkpoints lack.

    When ``state_dict`` has no
    ``switches.0.switch.attention_norm.accumulator_index`` entry, the
    ``accumulator``, ``accumulator_index`` and ``accumulator_filled`` entries
    of switches 0-3 are copied from this module's current state into
    ``state_dict`` (which is therefore modified in place) before the load is
    delegated to the parent class.

    Args:
        state_dict: Mapping from state names to values to load.
        strict: Forwarded unchanged to the parent ``load_state_dict``.

    Returns:
        Whatever the parent ``load_state_dict`` returns.
    """
    # Support backwards compatibility where accumulator_index and
    # accumulator_filled are not in this state_dict.
    if 'switches.0.switch.attention_norm.accumulator_index' not in state_dict:
        # Only snapshot the current state when it is actually needed.
        current_state = self.state_dict()
        # NOTE(review): the number of switches is hard-coded to 4 here;
        # confirm this matches every configuration this class is built with.
        for i in range(4):
            for suffix in ('accumulator', 'accumulator_index', 'accumulator_filled'):
                key = 'switches.%i.switch.attention_norm.%s' % (i, suffix)
                state_dict[key] = current_state[key]
    # super() has to be anchored on the class that owns this method; anchoring
    # it on DualOutputSRG raises TypeError because self is not an instance of it.
    return super(ConfigurableSwitchedResidualGenerator3, self).load_state_dict(state_dict, strict)
|
||||||
|
|
||||||
|
|
||||||
class DualOutputSRG(nn.Module):
|
class DualOutputSRG(nn.Module):
|
||||||
def __init__(self, switch_depth, switch_filters, switch_reductions, switch_processing_layers, trans_counts, trans_kernel_sizes,
|
def __init__(self, switch_depth, switch_filters, switch_reductions, switch_processing_layers, trans_counts, trans_kernel_sizes,
|
||||||
trans_layers, transformation_filters, initial_temp=20, final_temperature_step=50000, heightened_temp_min=1,
|
trans_layers, transformation_filters, initial_temp=20, final_temperature_step=50000, heightened_temp_min=1,
|
||||||
|
@ -411,14 +422,3 @@ class DualOutputSRG(nn.Module):
|
||||||
val["switch_%i_specificity" % (i,)] = means[i]
|
val["switch_%i_specificity" % (i,)] = means[i]
|
||||||
val["switch_%i_histogram" % (i,)] = hists[i]
|
val["switch_%i_histogram" % (i,)] = hists[i]
|
||||||
return val
|
return val
|
||||||
|
|
||||||
|
|
||||||
def load_state_dict(self, state_dict, strict=True):
    """Load ``state_dict``, back-filling accumulator entries that legacy checkpoints lack.

    When ``state_dict`` has no
    ``switches.0.switch.attention_norm.accumulator_index`` entry, the
    ``accumulator``, ``accumulator_index`` and ``accumulator_filled`` entries
    of switches 0-3 are copied from this module's current state into
    ``state_dict`` (which is therefore modified in place) before the load is
    delegated to the parent class.

    Args:
        state_dict: Mapping from state names to values to load.
        strict: Forwarded unchanged to the parent ``load_state_dict``.

    Returns:
        Whatever the parent ``load_state_dict`` returns.
    """
    # Support backwards compatibility where accumulator_index and
    # accumulator_filled are not in this state_dict.
    if 'switches.0.switch.attention_norm.accumulator_index' not in state_dict:
        # Only snapshot the current state when it is actually needed.
        current_state = self.state_dict()
        # NOTE(review): the number of switches is hard-coded to 4 here;
        # confirm this matches every configuration this class is built with.
        for i in range(4):
            for suffix in ('accumulator', 'accumulator_index', 'accumulator_filled'):
                key = 'switches.%i.switch.attention_norm.%s' % (i, suffix)
                state_dict[key] = current_state[key]
    # Propagate the parent's result instead of silently discarding it.
    return super(DualOutputSRG, self).load_state_dict(state_dict, strict)
|
|
Loading…
Reference in New Issue
Block a user