diff --git a/codes/models/archs/SwitchedResidualGenerator_arch.py b/codes/models/archs/SwitchedResidualGenerator_arch.py index 926380d0..8c0bb87b 100644 --- a/codes/models/archs/SwitchedResidualGenerator_arch.py +++ b/codes/models/archs/SwitchedResidualGenerator_arch.py @@ -241,7 +241,7 @@ class ConfigurableSwitchedResidualGenerator2(nn.Module): state_dict['switches.%i.switch.attention_norm.accumulator' % (i,)] = t_state['switches.%i.switch.attention_norm.accumulator' % (i,)] state_dict['switches.%i.switch.attention_norm.accumulator_index' % (i,)] = t_state['switches.%i.switch.attention_norm.accumulator_index' % (i,)] state_dict['switches.%i.switch.attention_norm.accumulator_filled' % (i,)] = t_state['switches.%i.switch.attention_norm.accumulator_filled' % (i,)] - super(DualOutputSRG, self).load_state_dict(state_dict, strict) + super(ConfigurableSwitchedResidualGenerator2, self).load_state_dict(state_dict, strict) class Interpolate(nn.Module):