Allow legacy state_dicts in srg2
This commit is contained in:
parent
1b1431133b
commit
853468ef82
|
@ -322,6 +322,17 @@ class ConfigurableSwitchedResidualGenerator3(nn.Module):
|
|||
return val
|
||||
|
||||
|
||||
def load_state_dict(self, state_dict, strict=True):
|
||||
# Support backwards compatibility where accumulator_index and accumulator_filled are not in this state_dict
|
||||
t_state = self.state_dict()
|
||||
if 'switches.0.switch.attention_norm.accumulator_index' not in state_dict.keys():
|
||||
for i in range(4):
|
||||
state_dict['switches.%i.switch.attention_norm.accumulator' % (i,)] = t_state['switches.%i.switch.attention_norm.accumulator' % (i,)]
|
||||
state_dict['switches.%i.switch.attention_norm.accumulator_index' % (i,)] = t_state['switches.%i.switch.attention_norm.accumulator_index' % (i,)]
|
||||
state_dict['switches.%i.switch.attention_norm.accumulator_filled' % (i,)] = t_state['switches.%i.switch.attention_norm.accumulator_filled' % (i,)]
|
||||
super(DualOutputSRG, self).load_state_dict(state_dict, strict)
|
||||
|
||||
|
||||
class DualOutputSRG(nn.Module):
|
||||
def __init__(self, switch_depth, switch_filters, switch_reductions, switch_processing_layers, trans_counts, trans_kernel_sizes,
|
||||
trans_layers, transformation_filters, initial_temp=20, final_temperature_step=50000, heightened_temp_min=1,
|
||||
|
@ -411,14 +422,3 @@ class DualOutputSRG(nn.Module):
|
|||
val["switch_%i_specificity" % (i,)] = means[i]
|
||||
val["switch_%i_histogram" % (i,)] = hists[i]
|
||||
return val
|
||||
|
||||
|
||||
def load_state_dict(self, state_dict, strict=True):
|
||||
# Support backwards compatibility where accumulator_index and accumulator_filled are not in this state_dict
|
||||
t_state = self.state_dict()
|
||||
if 'switches.0.switch.attention_norm.accumulator_index' not in state_dict.keys():
|
||||
for i in range(4):
|
||||
state_dict['switches.%i.switch.attention_norm.accumulator' % (i,)] = t_state['switches.%i.switch.attention_norm.accumulator' % (i,)]
|
||||
state_dict['switches.%i.switch.attention_norm.accumulator_index' % (i,)] = t_state['switches.%i.switch.attention_norm.accumulator_index' % (i,)]
|
||||
state_dict['switches.%i.switch.attention_norm.accumulator_filled' % (i,)] = t_state['switches.%i.switch.attention_norm.accumulator_filled' % (i,)]
|
||||
super(DualOutputSRG, self).load_state_dict(state_dict, strict)
|
Loading…
Reference in New Issue
Block a user