Scale anorm by transform count
parent dd0bbd9a7c
commit a2285ff2ee
@@ -123,7 +123,7 @@ class ConfigurableSwitchComputer(nn.Module):
         self.noise_scale = nn.Parameter(torch.full((1,), float(1e-3)))

         # And the switch itself, including learned scalars
-        self.switch = BareConvSwitch(initial_temperature=init_temp, attention_norm=AttentionNorm(transform_count, accumulator_size=128))
+        self.switch = BareConvSwitch(initial_temperature=init_temp, attention_norm=AttentionNorm(transform_count, accumulator_size=16 * transform_count))
         self.switch_scale = nn.Parameter(torch.full((1,), float(1)))
         self.post_switch_conv = ConvBnLelu(base_filters, base_filters, norm=False, bias=True)
         # The post_switch_conv gets a low scale initially. The network can decide to magnify it (or not)
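The change replaces the fixed accumulator_size=128 with 16 * transform_count, so the attention norm's history buffer grows with the number of transforms the switch routes between. Below is a minimal sketch of that idea, not the repository's actual AttentionNorm: it assumes the norm keeps a rolling accumulator of per-batch transform-usage statistics and re-weights attention by long-run usage. Only the names transform_count and accumulator_size come from the diff; everything inside the class is an assumption.

import torch
import torch.nn as nn

class RollingAttentionAccumulator(nn.Module):
    # Hypothetical stand-in for AttentionNorm: tracks how often each transform
    # is selected and re-weights attention by that long-run usage.
    def __init__(self, transform_count: int, accumulator_size: int):
        super().__init__()
        # One row per accumulated batch, one column per transform. Sizing this
        # buffer as 16 * transform_count (as in the commit) scales the history
        # with the number of transforms.
        self.register_buffer("accumulator", torch.zeros(accumulator_size, transform_count))
        self.register_buffer("index", torch.zeros(1, dtype=torch.long))

    def forward(self, attention: torch.Tensor) -> torch.Tensor:
        # attention: (..., transform_count), a soft assignment over transforms.
        if self.training:
            flat = attention.detach().reshape(-1, attention.shape[-1])
            # Overwrite the oldest slot with this batch's mean usage.
            slot = int(self.index.item()) % self.accumulator.shape[0]
            self.accumulator[slot] = flat.mean(dim=0)
            self.index += 1
        # Divide by long-run usage so rarely chosen transforms still receive
        # signal, then renormalize back to a valid distribution.
        usage = self.accumulator.mean(dim=0).clamp(min=1e-6)
        balanced = attention / usage
        return balanced / balanced.sum(dim=-1, keepdim=True)

Under that reading, the motivation for the change would be that with more transforms each one receives a smaller share of attention per batch, so its usage estimate is noisier; growing the accumulator with transform_count, rather than fixing it at 128, presumably compensates for that.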