Interpolate attention well before softmax

This commit is contained in:
James Betker 2020-07-06 09:18:30 -06:00
parent 72f90cabf8
commit 8f92c0a088

View File

@ -60,12 +60,14 @@ class SwitchComputer(nn.Module):
self.reduction_blocks = nn.ModuleList([HalvingProcessingBlock(filters * 2 ** i) for i in range(reduction_blocks)])
final_filters = filters * 2 ** reduction_blocks
self.processing_blocks, final_filters = create_sequential_growing_processing_block(final_filters, growth, processing_blocks)
proc_block_filters = max(final_filters // 2, transform_count)
self.proc_switch_conv = ConvBnLelu(final_filters, proc_block_filters, bn=False)
self.post_interpolate_decimate = ConvBnSilu(final_filters, filters, kernel_size=1, silu=False, bn=False)
self.interpolate_process = ConvBnSilu(filters, filters)
self.interpolate_process2 = ConvBnSilu(filters, filters)
tc = transform_count
if self.enable_negative_transforms:
tc = transform_count * 2
self.final_switch_conv = nn.Conv2d(proc_block_filters, tc, 1, 1, 0)
assert filters > transform_count * 2
self.final_switch_conv = nn.Conv2d(filters, tc, 1, 1, 0)
self.transforms = nn.ModuleList([transform_block() for _ in range(transform_count)])
self.add_noise = add_scalable_noise_to_transforms
@ -89,10 +91,13 @@ class SwitchComputer(nn.Module):
multiplexer = block.forward(multiplexer)
for block in self.processing_blocks:
multiplexer = block.forward(multiplexer)
multiplexer = self.proc_switch_conv(multiplexer)
multiplexer = self.final_switch_conv.forward(multiplexer)
# Interpolate the multiplexer across the entire shape of the image.
# Interpolate the multiplexer across the entire shape of the image, then perform some post-processing before feeding it into the switch.
multiplexer = F.interpolate(multiplexer, size=x.shape[2:], mode='nearest')
multiplexer = self.post_interpolate_decimate(multiplexer)
multiplexer = self.interpolate_process(multiplexer)
multiplexer = self.interpolate_process2(multiplexer)
multiplexer = self.final_switch_conv.forward(multiplexer)
outputs, attention = self.switch(xformed, multiplexer, True)
outputs = outputs * self.scale + self.bias