Set post_transform_block to None where applicable
parent 6f8705e8cb
commit c4543ce124
@@ -99,8 +99,7 @@ class ConfigurableSwitchComputer(nn.Module):
         # And the switch itself, including learned scalars
         self.switch = BareConvSwitch(initial_temperature=init_temp, attention_norm=AttentionNorm(transform_count, accumulator_size=anorm_multiplier * transform_count) if attention_norm else None)
         self.switch_scale = nn.Parameter(torch.full((1,), float(1)))
-        if post_transform_block is not None:
-            self.post_transform_block = post_transform_block
+        self.post_transform_block = post_transform_block
         if post_switch_conv:
             self.post_switch_conv = ConvBnLelu(base_filters, base_filters, norm=False, bias=True)
             # The post_switch_conv gets a low scale initially. The network can decide to magnify it (or not)
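For context on why the unconditional assignment matters: a minimal sketch, assuming the forward pass guards on the attribute (the forward method is not shown in this diff, and SwitchSketch is a hypothetical stand-in, not the real ConfigurableSwitchComputer):

import torch
import torch.nn as nn

class SwitchSketch(nn.Module):
    """Hypothetical stand-in illustrating the constructor change."""
    def __init__(self, post_transform_block=None):
        super().__init__()
        # After this commit: always assign, even when the argument is None,
        # so the attribute is guaranteed to exist on every instance.
        self.post_transform_block = post_transform_block

    def forward(self, x):
        # Assumed usage pattern: apply the block only when one was provided.
        # With the old conditional assignment, this check would raise
        # AttributeError whenever post_transform_block was None.
        if self.post_transform_block is not None:
            x = self.post_transform_block(x)
        return x

# Usage: construct with and without a post-transform block.
m_plain = SwitchSketch()                    # post_transform_block is None
m_block = SwitchSketch(nn.Conv2d(3, 3, 1))  # with a block
y = m_plain(torch.randn(1, 3, 8, 8))        # no AttributeError now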