Fixed ConfigurableSwitchedGenerator bug

This commit is contained in:
James Betker 2020-06-16 16:53:57 -06:00
parent 7d541642aa
commit 6f8406fbdc

View File

@ -86,7 +86,7 @@ class SwitchComputer(nn.Module):
multiplexer = self.proc_switch_conv(multiplexer)
multiplexer = self.final_switch_conv.forward(multiplexer)
# Interpolate the multiplexer across the entire shape of the image.
multiplexer = F.interpolate(multiplexer, size=x.shape[2:], mode='nearest', recompute_scale_factor=True)
multiplexer = F.interpolate(multiplexer, size=x.shape[2:], mode='nearest')
return self.switch(xformed, multiplexer, output_attention_weights)
@ -111,7 +111,8 @@ class ConfigurableSwitchedResidualGenerator(nn.Module):
def forward(self, x):
self.attentions = []
for i, sw in enumerate(self.switches):
x, att = sw.forward(x, True)
sw_out, att = sw.forward(x, True)
x = x + sw_out
self.attentions.append(att)
return x,