From a2285ff2eece0e249f734e87435304d463d7e237 Mon Sep 17 00:00:00 2001 From: James Betker Date: Mon, 13 Jul 2020 08:49:09 -0600 Subject: [PATCH] Scale anorm by transform count --- codes/models/archs/SwitchedResidualGenerator_arch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/codes/models/archs/SwitchedResidualGenerator_arch.py b/codes/models/archs/SwitchedResidualGenerator_arch.py index ea0a167c..0cde385e 100644 --- a/codes/models/archs/SwitchedResidualGenerator_arch.py +++ b/codes/models/archs/SwitchedResidualGenerator_arch.py @@ -123,7 +123,7 @@ class ConfigurableSwitchComputer(nn.Module): self.noise_scale = nn.Parameter(torch.full((1,), float(1e-3))) # And the switch itself, including learned scalars - self.switch = BareConvSwitch(initial_temperature=init_temp, attention_norm=AttentionNorm(transform_count, accumulator_size=128)) + self.switch = BareConvSwitch(initial_temperature=init_temp, attention_norm=AttentionNorm(transform_count, accumulator_size=16 * transform_count)) self.switch_scale = nn.Parameter(torch.full((1,), float(1))) self.post_switch_conv = ConvBnLelu(base_filters, base_filters, norm=False, bias=True) # The post_switch_conv gets a low scale initially. The network can decide to magnify it (or not)