From 853468ef82c515d45d9e1156a6f43191b126c141 Mon Sep 17 00:00:00 2001 From: James Betker Date: Tue, 14 Jul 2020 10:03:45 -0600 Subject: [PATCH] Allow legacy state_dicts in srg2 --- .../archs/SwitchedResidualGenerator_arch.py | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/codes/models/archs/SwitchedResidualGenerator_arch.py b/codes/models/archs/SwitchedResidualGenerator_arch.py index e42ec2a0..8a0b9de8 100644 --- a/codes/models/archs/SwitchedResidualGenerator_arch.py +++ b/codes/models/archs/SwitchedResidualGenerator_arch.py @@ -322,6 +322,17 @@ class ConfigurableSwitchedResidualGenerator3(nn.Module): return val + def load_state_dict(self, state_dict, strict=True): + # Support backwards compatibility where accumulator_index and accumulator_filled are not in this state_dict + t_state = self.state_dict() + if 'switches.0.switch.attention_norm.accumulator_index' not in state_dict.keys(): + for i in range(4): + state_dict['switches.%i.switch.attention_norm.accumulator' % (i,)] = t_state['switches.%i.switch.attention_norm.accumulator' % (i,)] + state_dict['switches.%i.switch.attention_norm.accumulator_index' % (i,)] = t_state['switches.%i.switch.attention_norm.accumulator_index' % (i,)] + state_dict['switches.%i.switch.attention_norm.accumulator_filled' % (i,)] = t_state['switches.%i.switch.attention_norm.accumulator_filled' % (i,)] + super(DualOutputSRG, self).load_state_dict(state_dict, strict) + + class DualOutputSRG(nn.Module): def __init__(self, switch_depth, switch_filters, switch_reductions, switch_processing_layers, trans_counts, trans_kernel_sizes, trans_layers, transformation_filters, initial_temp=20, final_temperature_step=50000, heightened_temp_min=1, @@ -410,15 +421,4 @@ class DualOutputSRG(nn.Module): for i in range(len(means)): val["switch_%i_specificity" % (i,)] = means[i] val["switch_%i_histogram" % (i,)] = hists[i] - return val - - - def load_state_dict(self, state_dict, strict=True): - # Support backwards compatibility where accumulator_index and accumulator_filled are not in this state_dict - t_state = self.state_dict() - if 'switches.0.switch.attention_norm.accumulator_index' not in state_dict.keys(): - for i in range(4): - state_dict['switches.%i.switch.attention_norm.accumulator' % (i,)] = t_state['switches.%i.switch.attention_norm.accumulator' % (i,)] - state_dict['switches.%i.switch.attention_norm.accumulator_index' % (i,)] = t_state['switches.%i.switch.attention_norm.accumulator_index' % (i,)] - state_dict['switches.%i.switch.attention_norm.accumulator_filled' % (i,)] = t_state['switches.%i.switch.attention_norm.accumulator_filled' % (i,)] - super(DualOutputSRG, self).load_state_dict(state_dict, strict) \ No newline at end of file + return val \ No newline at end of file