Fix switch norm average

James Betker 2021-06-06 15:04:28 -06:00
parent 57e1a6a0f2
commit 9a6991e461


@@ -149,15 +149,16 @@ class SwitchNorm(nn.Module):
         # Compute the norm factor.
         if self.accumulator_filled > 0:
             norm = torch.mean(self.accumulator, dim=0)
+            norm = norm * x.shape[1] / norm.sum()  # The resulting norm should sum up to the total breadth: we are just re-weighting here.
         else:
             norm = torch.ones(self.group_size, device=self.accumulator.device)
+        norm = norm.view(1,-1)
         while len(x.shape) < len(norm.shape):
             norm = norm.unsqueeze(-1)
         x = x / norm
-        # Need to re-normalize x so that the groups dimension sum to 1, just like when it was fed in.
-        return x / x.sum(dim=1, keepdim=True)
+        return x

 class HardRoutingGate(nn.Module):
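
For reference, a minimal sketch of how the norm-factor step behaves after this commit, assuming x is a (batch, groups, ...) tensor of switch weights and self.accumulator is a buffer of recent per-group usage averages; the standalone function and its argument names are hypothetical stand-ins for the SwitchNorm attributes shown in the diff:

import torch

def reweight(x, accumulator, accumulator_filled, group_size):
    # Hypothetical stand-in for the SwitchNorm norm-factor step after this commit.
    if accumulator_filled > 0:
        norm = torch.mean(accumulator, dim=0)
        # Rescale so the factors sum to the group count (x.shape[1]); dividing
        # by them then re-weights the groups rather than shrinking x overall.
        norm = norm * x.shape[1] / norm.sum()
    else:
        # No usage history accumulated yet: leave x untouched.
        norm = torch.ones(group_size, device=accumulator.device)
    norm = norm.view(1, -1)
    # Broadcast the per-group factor across any trailing spatial dimensions.
    while norm.dim() < x.dim():
        norm = norm.unsqueeze(-1)
    return x / norm

With the factor rescaled this way, the diff also drops the extra re-normalization of x along the group dimension that the old return statement performed.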