From 9a6991e461a5e64f48d77ee0b13c271943c04bcd Mon Sep 17 00:00:00 2001
From: James Betker
Date: Sun, 6 Jun 2021 15:04:28 -0600
Subject: [PATCH] Fix switch norm average

---
 codes/models/switched_conv/switched_conv_hard_routing.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/codes/models/switched_conv/switched_conv_hard_routing.py b/codes/models/switched_conv/switched_conv_hard_routing.py
index 73e9ad09..2b2c73b3 100644
--- a/codes/models/switched_conv/switched_conv_hard_routing.py
+++ b/codes/models/switched_conv/switched_conv_hard_routing.py
@@ -149,15 +149,16 @@ class SwitchNorm(nn.Module):
         # Compute the norm factor.
         if self.accumulator_filled > 0:
             norm = torch.mean(self.accumulator, dim=0)
+            norm = norm * x.shape[1] / norm.sum()  # The resulting norm should sum up to the total breadth: we are just re-weighting here.
         else:
             norm = torch.ones(self.group_size, device=self.accumulator.device)
+        norm = norm.view(1,-1)
         while len(norm.shape) < len(x.shape):
             norm = norm.unsqueeze(-1)
         x = x / norm
-        # Need to re-normalize x so that the groups dimension sum to 1, just like when it was fed in.
-        return x / x.sum(dim=1, keepdim=True)
+        return x
 
 
 class HardRoutingGate(nn.Module):
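
For context (not part of the patch): a minimal, hedged sketch of the re-weighting this change introduces. Before the fix, the forward pass divided by the accumulated per-group mean and then re-normalized the result to sum to 1 over the groups dimension; after the fix, the norm itself is re-weighted so it sums to the number of groups ("the total breadth") before the division, and the explicit re-normalization step is dropped. The helper name apply_switch_norm and the toy shapes below are assumptions for illustration, not code from the repository.

    import torch

    def apply_switch_norm(x: torch.Tensor, accumulator: torch.Tensor) -> torch.Tensor:
        # x: (batch, groups, h, w) routing weights; accumulator: (buffer, groups)
        # of previously observed per-group norms (hypothetical stand-ins for
        # SwitchNorm's input and its accumulator buffer).
        norm = torch.mean(accumulator, dim=0)   # per-group average, shape (groups,)
        norm = norm * x.shape[1] / norm.sum()   # re-weight so norm sums to the group count
        norm = norm.view(1, -1)                 # (1, groups)
        while len(norm.shape) < len(x.shape):
            norm = norm.unsqueeze(-1)           # -> (1, groups, 1, 1), broadcastable over x
        return x / norm

    x = torch.softmax(torch.randn(2, 8, 4, 4), dim=1)  # toy routing weights, sum to 1 over groups
    acc = torch.rand(128, 8) + 0.5                     # toy accumulator of past norms
    print(apply_switch_norm(x, acc).shape)             # torch.Size([2, 8, 4, 4])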