Allow legacy state_dicts in srg2

This commit is contained in:
James Betker 2020-07-14 10:03:45 -06:00
parent 1b1431133b
commit 853468ef82

View File

@ -322,6 +322,17 @@ class ConfigurableSwitchedResidualGenerator3(nn.Module):
return val
def load_state_dict(self, state_dict, strict=True):
    """Load ``state_dict``, tolerating legacy checkpoints without accumulator state.

    A checkpoint is treated as legacy when it has no
    ``switches.0.switch.attention_norm.accumulator_index`` entry. In that case
    the ``accumulator``, ``accumulator_index`` and ``accumulator_filled``
    entries of switches 0-3 are taken from this module's current state and
    written into ``state_dict`` before it is handed to the parent class.

    Args:
        state_dict: Mapping of parameter/buffer names to values. It is
            modified in place when the legacy path is taken.
        strict: Forwarded unchanged to the parent ``load_state_dict``.

    Returns:
        Whatever the parent class's ``load_state_dict`` returns.
    """
    # Support backwards compatibility where accumulator_index and accumulator_filled are not in this state_dict
    if 'switches.0.switch.attention_norm.accumulator_index' not in state_dict:
        # Only snapshot the model's own state when it is actually needed.
        t_state = self.state_dict()
        # NOTE(review): the switch count is hard-coded to 4 - confirm this matches
        # every configuration this class can be built with.
        for i in range(4):
            prefix = 'switches.%i.switch.attention_norm.' % (i,)
            # 'accumulator' is overwritten even if the checkpoint has one, exactly
            # as before; the other two entries are the ones legacy files lack.
            for name in ('accumulator', 'accumulator_index', 'accumulator_filled'):
                state_dict[prefix + name] = t_state[prefix + name]
    # Zero-argument super() resolves against the class this method is defined in.
    # The previous super(DualOutputSRG, self) named a different class and raised
    # TypeError because self is not a DualOutputSRG instance.
    return super().load_state_dict(state_dict, strict)
class DualOutputSRG(nn.Module):
def __init__(self, switch_depth, switch_filters, switch_reductions, switch_processing_layers, trans_counts, trans_kernel_sizes,
trans_layers, transformation_filters, initial_temp=20, final_temperature_step=50000, heightened_temp_min=1,
@ -411,14 +422,3 @@ class DualOutputSRG(nn.Module):
val["switch_%i_specificity" % (i,)] = means[i]
val["switch_%i_histogram" % (i,)] = hists[i]
return val
def load_state_dict(self, state_dict, strict=True):
    """Load ``state_dict``, patching in accumulator entries for older checkpoints.

    When ``state_dict`` has no ``switches.0.switch.attention_norm.accumulator_index``
    key, the three accumulator-related entries of switches 0-3 are copied from
    this module's current state into ``state_dict`` (in place) before loading.
    """
    # Backwards compatibility: older state_dicts lack accumulator_index and accumulator_filled.
    current = self.state_dict()
    key_templates = (
        'switches.%i.switch.attention_norm.accumulator',
        'switches.%i.switch.attention_norm.accumulator_index',
        'switches.%i.switch.attention_norm.accumulator_filled',
    )
    is_legacy = 'switches.0.switch.attention_norm.accumulator_index' not in state_dict
    if is_legacy:
        for switch_idx in range(4):
            for template in key_templates:
                key = template % (switch_idx,)
                state_dict[key] = current[key]
    super(DualOutputSRG, self).load_state_dict(state_dict, strict)