forked from mrq/DL-Art-School
Interpolate attention well before softmax
parent 72f90cabf8
commit 8f92c0a088
@@ -60,12 +60,14 @@ class SwitchComputer(nn.Module):
         self.reduction_blocks = nn.ModuleList([HalvingProcessingBlock(filters * 2 ** i) for i in range(reduction_blocks)])
         final_filters = filters * 2 ** reduction_blocks
         self.processing_blocks, final_filters = create_sequential_growing_processing_block(final_filters, growth, processing_blocks)
-        proc_block_filters = max(final_filters // 2, transform_count)
-        self.proc_switch_conv = ConvBnLelu(final_filters, proc_block_filters, bn=False)
+        self.post_interpolate_decimate = ConvBnSilu(final_filters, filters, kernel_size=1, silu=False, bn=False)
+        self.interpolate_process = ConvBnSilu(filters, filters)
+        self.interpolate_process2 = ConvBnSilu(filters, filters)
         tc = transform_count
         if self.enable_negative_transforms:
             tc = transform_count * 2
-        self.final_switch_conv = nn.Conv2d(proc_block_filters, tc, 1, 1, 0)
+        assert filters > transform_count * 2
+        self.final_switch_conv = nn.Conv2d(filters, tc, 1, 1, 0)

         self.transforms = nn.ModuleList([transform_block() for _ in range(transform_count)])
         self.add_noise = add_scalable_noise_to_transforms
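For context, a minimal sketch of the channel plan these new constructor lines set up, using plain nn.Conv2d / nn.SiLU stand-ins for the repo's ConvBnSilu block; the stand-in layer makeup and the example values for filters, final_filters and transform_count are assumptions, not taken from the repo:

import torch
import torch.nn as nn

def conv_silu(in_ch, out_ch, kernel_size=3, silu=True):
    # Assumed stand-in for ConvBnSilu with bn=False: a padded conv plus optional SiLU.
    conv = nn.Conv2d(in_ch, out_ch, kernel_size, padding=kernel_size // 2)
    return nn.Sequential(conv, nn.SiLU()) if silu else conv

filters, final_filters, transform_count = 64, 256, 8   # example values only
tc = transform_count                                    # doubled when negative transforms are enabled

# The new multiplexer head works at `filters` channels after interpolation:
# a 1x1 decimation from final_filters, two processing convs, then one logit map per transform.
post_interpolate_decimate = conv_silu(final_filters, filters, kernel_size=1, silu=False)
interpolate_process = conv_silu(filters, filters)
interpolate_process2 = conv_silu(filters, filters)
assert filters > transform_count * 2
final_switch_conv = nn.Conv2d(filters, tc, 1, 1, 0)

feat = torch.randn(1, final_filters, 16, 16)
logits = final_switch_conv(interpolate_process2(interpolate_process(post_interpolate_decimate(feat))))
print(logits.shape)  # torch.Size([1, 8, 16, 16])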
|
@@ -89,10 +91,13 @@ class SwitchComputer(nn.Module):
             multiplexer = block.forward(multiplexer)
         for block in self.processing_blocks:
             multiplexer = block.forward(multiplexer)
-        multiplexer = self.proc_switch_conv(multiplexer)
-        multiplexer = self.final_switch_conv.forward(multiplexer)
-        # Interpolate the multiplexer across the entire shape of the image.
+
+        # Interpolate the multiplexer across the entire shape of the image perform some post-processing before feeding into switch.
         multiplexer = F.interpolate(multiplexer, size=x.shape[2:], mode='nearest')
+        multiplexer = self.post_interpolate_decimate(multiplexer)
+        multiplexer = self.interpolate_process(multiplexer)
+        multiplexer = self.interpolate_process2(multiplexer)
+        multiplexer = self.final_switch_conv.forward(multiplexer)

         outputs, attention = self.switch(xformed, multiplexer, True)
         outputs = outputs * self.scale + self.bias
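A minimal, self-contained sketch of the reordered multiplexer path in forward(): the coarse switch features are upsampled to the image size first, and only then decimated and turned into per-transform logits, so the softmax sees attention maps that were interpolated well before it (matching the commit title). All shapes, the stacked layout of xformed, and the softmax blend attributed to self.switch are assumptions for illustration, not the repo's actual switch implementation:

import torch
import torch.nn as nn
import torch.nn.functional as F

b, filters, final_filters, transform_count = 1, 64, 256, 8   # assumed sizes
h, w = 32, 32

# Stand-ins: stacked per-transform outputs and the coarse multiplexer features
# produced by the reduction/processing blocks (both layouts are assumptions).
xformed = torch.randn(b, transform_count, filters, h, w)
multiplexer = torch.randn(b, final_filters, h // 4, w // 4)

decimate = nn.Conv2d(final_filters, filters, 1)              # stands in for post_interpolate_decimate
final_switch_conv = nn.Conv2d(filters, transform_count, 1)

# New ordering from this commit: interpolate to the image size first,
# then compute the switch logits at full resolution.
multiplexer = F.interpolate(multiplexer, size=(h, w), mode='nearest')
logits = final_switch_conv(decimate(multiplexer))

# Inside self.switch (not shown in the diff) the logits are presumably
# softmax-normalized over the transform dimension and used to blend xformed.
attention = F.softmax(logits, dim=1)                         # (b, transform_count, h, w)
outputs = (xformed * attention.unsqueeze(2)).sum(dim=1)      # (b, filters, h, w)
print(outputs.shape)                                         # torch.Size([1, 64, 32, 32])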