forked from mrq/DL-Art-School
Fixed ConfigurableSwitchedGenerator bug
parent 7d541642aa
commit 6f8406fbdc
@@ -86,7 +86,7 @@ class SwitchComputer(nn.Module):
         multiplexer = self.proc_switch_conv(multiplexer)
         multiplexer = self.final_switch_conv.forward(multiplexer)
         # Interpolate the multiplexer across the entire shape of the image.
-        multiplexer = F.interpolate(multiplexer, size=x.shape[2:], mode='nearest', recompute_scale_factor=True)
+        multiplexer = F.interpolate(multiplexer, size=x.shape[2:], mode='nearest')
 
         return self.switch(xformed, multiplexer, output_attention_weights)
 
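The first hunk drops recompute_scale_factor=True. In PyTorch, F.interpolate only accepts that flag together with scale_factor; combining it with an explicit size raises a ValueError, so the old call could not execute. A minimal sketch of the corrected call, with hypothetical tensor shapes standing in for the real activations:

# Minimal sketch (hypothetical shapes): recompute_scale_factor is only valid
# alongside scale_factor; pairing it with size= raises a ValueError.
import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 64, 64)            # hypothetical full-resolution input
multiplexer = torch.randn(1, 8, 16, 16)  # hypothetical low-resolution switch map

# Fixed call: nearest-neighbor upsample to the spatial dims of x.
multiplexer = F.interpolate(multiplexer, size=x.shape[2:], mode='nearest')
assert multiplexer.shape[2:] == x.shape[2:]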
@@ -111,7 +111,8 @@ class ConfigurableSwitchedResidualGenerator(nn.Module):
     def forward(self, x):
         self.attentions = []
         for i, sw in enumerate(self.switches):
-            x, att = sw.forward(x, True)
+            sw_out, att = sw.forward(x, True)
+            x = x + sw_out
             self.attentions.append(att)
         return x,
 
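The second hunk is the substantive fix: the old loop rebound x to each switch's output, so the "residual" generator was not residual at all; the new code adds the switch output back onto x. A minimal sketch of the restored pattern, using a toy stand-in for SwitchComputer (hypothetical module; the real one also consumes a multiplexer and supports an output_attention_weights flag):

import torch
import torch.nn as nn

class ToySwitch(nn.Module):
    # Hypothetical stand-in for SwitchComputer: returns (transform(x), attention).
    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, 3, padding=1)

    def forward(self, x, output_attention_weights=True):
        out = self.conv(x)
        att = torch.softmax(out, dim=1)  # placeholder attention weights
        return out, att

class ToyResidualGenerator(nn.Module):
    def __init__(self, channels, n_switches):
        super().__init__()
        self.switches = nn.ModuleList(ToySwitch(channels) for _ in range(n_switches))
        self.attentions = []

    def forward(self, x):
        self.attentions = []
        for i, sw in enumerate(self.switches):
            sw_out, att = sw.forward(x, True)
            x = x + sw_out              # the fix: accumulate, don't overwrite
            self.attentions.append(att)
        return x,

gen = ToyResidualGenerator(channels=3, n_switches=4)
out, = gen(torch.randn(1, 3, 32, 32))

Keeping x on an additive path gives every switch an identity shortcut, so gradients reach the early switches directly, which is the usual reason residual stacks train more stably than plain sequential ones.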