diff --git a/codes/models/archs/SwitchedResidualGenerator_arch.py b/codes/models/archs/SwitchedResidualGenerator_arch.py index 0998f8bb..a2a41c99 100644 --- a/codes/models/archs/SwitchedResidualGenerator_arch.py +++ b/codes/models/archs/SwitchedResidualGenerator_arch.py @@ -86,7 +86,7 @@ class SwitchComputer(nn.Module): multiplexer = self.proc_switch_conv(multiplexer) multiplexer = self.final_switch_conv.forward(multiplexer) # Interpolate the multiplexer across the entire shape of the image. - multiplexer = F.interpolate(multiplexer, size=x.shape[2:], mode='nearest', recompute_scale_factor=True) + multiplexer = F.interpolate(multiplexer, size=x.shape[2:], mode='nearest') return self.switch(xformed, multiplexer, output_attention_weights) @@ -111,7 +111,8 @@ class ConfigurableSwitchedResidualGenerator(nn.Module): def forward(self, x): self.attentions = [] for i, sw in enumerate(self.switches): - x, att = sw.forward(x, True) + sw_out, att = sw.forward(x, True) + x = x + sw_out self.attentions.append(att) return x,