From 6f8406fbdc213eb568eff5d82b2172e5ca0d4e7f Mon Sep 17 00:00:00 2001 From: James Betker Date: Tue, 16 Jun 2020 16:53:57 -0600 Subject: [PATCH] Fixed ConfigurableSwitchedGenerator bug --- codes/models/archs/SwitchedResidualGenerator_arch.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/codes/models/archs/SwitchedResidualGenerator_arch.py b/codes/models/archs/SwitchedResidualGenerator_arch.py index 0998f8bb..a2a41c99 100644 --- a/codes/models/archs/SwitchedResidualGenerator_arch.py +++ b/codes/models/archs/SwitchedResidualGenerator_arch.py @@ -86,7 +86,7 @@ class SwitchComputer(nn.Module): multiplexer = self.proc_switch_conv(multiplexer) multiplexer = self.final_switch_conv.forward(multiplexer) # Interpolate the multiplexer across the entire shape of the image. - multiplexer = F.interpolate(multiplexer, size=x.shape[2:], mode='nearest', recompute_scale_factor=True) + multiplexer = F.interpolate(multiplexer, size=x.shape[2:], mode='nearest') return self.switch(xformed, multiplexer, output_attention_weights) @@ -111,7 +111,8 @@ class ConfigurableSwitchedResidualGenerator(nn.Module): def forward(self, x): self.attentions = [] for i, sw in enumerate(self.switches): - x, att = sw.forward(x, True) + sw_out, att = sw.forward(x, True) + x = x + sw_out self.attentions.append(att) return x,