diff --git a/codes/models/switched_conv/switched_conv_hard_routing.py b/codes/models/switched_conv/switched_conv_hard_routing.py
index 73e9ad09..2b2c73b3 100644
--- a/codes/models/switched_conv/switched_conv_hard_routing.py
+++ b/codes/models/switched_conv/switched_conv_hard_routing.py
@@ -149,15 +149,16 @@ class SwitchNorm(nn.Module):
         # Compute the norm factor.
         if self.accumulator_filled > 0:
             norm = torch.mean(self.accumulator, dim=0)
+            norm = norm * x.shape[1] / norm.sum()  # The resulting norm should sum up to the total breadth: we are just re-weighting here.
         else:
             norm = torch.ones(self.group_size, device=self.accumulator.device)
+        norm = norm.view(1,-1)
         while len(x.shape) > len(norm.shape):
             norm = norm.unsqueeze(-1)
         x = x / norm
-        # Need to re-normalize x so that the groups dimension sum to 1, just like when it was fed in.
-        return x / x.sum(dim=1, keepdim=True)
+        return x


 class HardRoutingGate(nn.Module):
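
For context, here is a minimal standalone sketch of what the new normalization does (not part of the commit; the tensor shapes and the `means` stand-in are assumptions). Rescaling the accumulated per-group means so they sum to the group breadth `x.shape[1]` makes the factors average 1.0, so dividing by them re-weights groups relative to each other without changing the overall scale of `x`:

    import torch

    # Hypothetical demo of the re-weighting introduced above; `means` stands in
    # for torch.mean(self.accumulator, dim=0), the average usage of each group.
    G = 8
    means = torch.rand(G)

    # Rescale so the factors sum to the total breadth G (i.e. they average 1.0).
    norm = means * G / means.sum()
    assert torch.isclose(norm.sum(), torch.tensor(float(G)))

    # Broadcast as in the patched forward(): view as (1, G), then pad trailing
    # dims until norm matches the rank of x, e.g. x of shape (B, G, H, W).
    x = torch.rand(2, G, 4, 4)
    norm = norm.view(1, -1)
    while len(x.shape) > len(norm.shape):
        norm = norm.unsqueeze(-1)
    out = x / norm  # over-used groups are scaled down, under-used groups up
    print(out.shape)  # torch.Size([2, 8, 4, 4])

Note that because the factors average 1.0, the output stays on roughly the same scale as the input, but (unlike before this commit) it is no longer re-normalized to an exact distribution over the group dimension.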