forked from mrq/DL-Art-School
Add negative_transforms switch to resgen
This commit is contained in:
parent
dfcbe5f2db
commit
0584c3b587
|
@ -74,17 +74,23 @@ def create_sequential_growing_processing_block(filters_init, filter_growth, num_
|
||||||
|
|
||||||
|
|
||||||
class SwitchComputer(nn.Module):
|
class SwitchComputer(nn.Module):
|
||||||
def __init__(self, channels_in, filters, growth, transform_block, transform_count, reduction_blocks, processing_blocks=0, init_temp=20):
|
def __init__(self, channels_in, filters, growth, transform_block, transform_count, reduction_blocks, processing_blocks=0,
|
||||||
|
init_temp=20, enable_negative_transforms=False):
|
||||||
super(SwitchComputer, self).__init__()
|
super(SwitchComputer, self).__init__()
|
||||||
|
self.enable_negative_transforms = enable_negative_transforms
|
||||||
|
|
||||||
self.filter_conv = ConvBnLelu(channels_in, filters)
|
self.filter_conv = ConvBnLelu(channels_in, filters)
|
||||||
self.reduction_blocks = nn.ModuleList([HalvingProcessingBlock(filters * 2 ** i) for i in range(reduction_blocks)])
|
self.reduction_blocks = nn.ModuleList([HalvingProcessingBlock(filters * 2 ** i) for i in range(reduction_blocks)])
|
||||||
final_filters = filters * 2 ** reduction_blocks
|
final_filters = filters * 2 ** reduction_blocks
|
||||||
self.processing_blocks, final_filters = create_sequential_growing_processing_block(final_filters, growth, processing_blocks)
|
self.processing_blocks, final_filters = create_sequential_growing_processing_block(final_filters, growth, processing_blocks)
|
||||||
proc_block_filters = max(final_filters // 2, transform_count)
|
proc_block_filters = max(final_filters // 2, transform_count)
|
||||||
self.proc_switch_conv = ConvBnLelu(final_filters, proc_block_filters, bn=False)
|
self.proc_switch_conv = ConvBnLelu(final_filters, proc_block_filters, bn=False)
|
||||||
self.final_switch_conv = nn.Conv2d(proc_block_filters, transform_count, 1, 1, 0)
|
tc = transform_count
|
||||||
|
if self.enable_negative_transforms:
|
||||||
|
tc = transform_count * 2
|
||||||
|
self.final_switch_conv = nn.Conv2d(proc_block_filters, tc, 1, 1, 0)
|
||||||
|
|
||||||
self.transforms = nn.ModuleList([transform_block() for i in range(transform_count)])
|
self.transforms = nn.ModuleList([transform_block() for _ in range(transform_count)])
|
||||||
|
|
||||||
# And the switch itself, including learned scalars
|
# And the switch itself, including learned scalars
|
||||||
self.switch = BareConvSwitch(initial_temperature=init_temp)
|
self.switch = BareConvSwitch(initial_temperature=init_temp)
|
||||||
|
@ -93,6 +99,8 @@ class SwitchComputer(nn.Module):
|
||||||
|
|
||||||
def forward(self, x, output_attention_weights=False):
|
def forward(self, x, output_attention_weights=False):
|
||||||
xformed = [t.forward(x) for t in self.transforms]
|
xformed = [t.forward(x) for t in self.transforms]
|
||||||
|
if self.enable_negative_transforms:
|
||||||
|
xformed.extend([-t for t in xformed])
|
||||||
|
|
||||||
multiplexer = self.filter_conv(x)
|
multiplexer = self.filter_conv(x)
|
||||||
for block in self.reduction_blocks:
|
for block in self.reduction_blocks:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user