Fix switch norm average

James Betker 2021-06-06 15:04:28 -06:00
parent 57e1a6a0f2
commit 9a6991e461


@@ -149,15 +149,16 @@ class SwitchNorm(nn.Module):
         # Compute the norm factor.
         if self.accumulator_filled > 0:
             norm = torch.mean(self.accumulator, dim=0)
             norm = norm * x.shape[1] / norm.sum()   # The resulting norm should sum up to the total breadth: we are just re-weighting here.
         else:
             norm = torch.ones(self.group_size, device=self.accumulator.device)
         norm = norm.view(1,-1)
         while len(x.shape) > len(norm.shape):
             norm = norm.unsqueeze(-1)
         x = x / norm
-        return x
+        # Need to re-normalize x so that the groups dimension sum to 1, just like when it was fed in.
+        return x / x.sum(dim=1, keepdim=True)
class HardRoutingGate(nn.Module):
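
For context, here is a minimal, self-contained sketch of what the normalization path does after this change (dummy tensors and assumed shapes, not the repository code): the accumulator average is re-weighted so its entries sum to the number of groups, applied as a per-group divisor, and the result is re-normalized so the group dimension sums to 1 again, as the comments in the diff describe.

import torch

batch, groups, h, w = 2, 8, 4, 4
x = torch.softmax(torch.randn(batch, groups, h, w), dim=1)  # group dim sums to 1 on input

accumulator = torch.rand(128, groups)   # stand-in for the accumulated per-group usage buffer
norm = torch.mean(accumulator, dim=0)   # average usage per group
norm = norm * x.shape[1] / norm.sum()   # re-weight so the factors sum to the group count

norm = norm.view(1, -1)
while len(x.shape) > len(norm.shape):   # pad norm to (1, groups, 1, 1) for broadcasting
    norm = norm.unsqueeze(-1)

x = x / norm                            # damp historically over-used groups, boost under-used ones
x = x / x.sum(dim=1, keepdim=True)      # restore sum-to-1 along the group dimension
print(x.sum(dim=1))                     # ~1.0 everywhere

The apparent intent, going by the in-code comments, is that dividing by the historical usage average nudges routing toward more even group usage, while the final division keeps the output a valid distribution over the groups dimension.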