Fixed ConfigurableSwitchedGenerator bug
This commit is contained in:
parent
7d541642aa
commit
6f8406fbdc
|
@ -86,7 +86,7 @@ class SwitchComputer(nn.Module):
|
|||
multiplexer = self.proc_switch_conv(multiplexer)
|
||||
multiplexer = self.final_switch_conv.forward(multiplexer)
|
||||
# Interpolate the multiplexer across the entire shape of the image.
|
||||
multiplexer = F.interpolate(multiplexer, size=x.shape[2:], mode='nearest', recompute_scale_factor=True)
|
||||
multiplexer = F.interpolate(multiplexer, size=x.shape[2:], mode='nearest')
|
||||
|
||||
return self.switch(xformed, multiplexer, output_attention_weights)
|
||||
|
||||
|
@ -111,7 +111,8 @@ class ConfigurableSwitchedResidualGenerator(nn.Module):
|
|||
def forward(self, x):
|
||||
self.attentions = []
|
||||
for i, sw in enumerate(self.switches):
|
||||
x, att = sw.forward(x, True)
|
||||
sw_out, att = sw.forward(x, True)
|
||||
x = x + sw_out
|
||||
self.attentions.append(att)
|
||||
return x,
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user