Fix switch norm average
parent 57e1a6a0f2
commit 9a6991e461

@@ -149,15 +149,16 @@ class SwitchNorm(nn.Module):
        # Compute the norm factor.
        if self.accumulator_filled > 0:
            norm = torch.mean(self.accumulator, dim=0)
            norm = norm * x.shape[1] / norm.sum()  # The resulting norm should sum up to the total breadth: we are just re-weighting here.
        else:
            norm = torch.ones(self.group_size, device=self.accumulator.device)

        norm = norm.view(1,-1)
        while len(x.shape) > len(norm.shape):
            norm = norm.unsqueeze(-1)
        x = x / norm

        # Need to re-normalize x so that the groups dimension sum to 1, just like when it was fed in.
        return x / x.sum(dim=1, keepdim=True)
        return x


class HardRoutingGate(nn.Module):
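
Below is a minimal, self-contained sketch (not the repository's code) of the re-weighting and re-normalization this commit introduces. The toy shapes, the softmax input, and the accumulator contents are assumptions made for illustration; only the arithmetic mirrors the diff above.

    import torch

    group_size = 4
    # Hypothetical inputs: x holds per-position group weights that sum to 1 over dim=1,
    # and accumulator holds recent per-group usage statistics (as in SwitchNorm above).
    x = torch.softmax(torch.randn(2, group_size, 8, 8), dim=1)
    accumulator = torch.rand(10, group_size)

    norm = torch.mean(accumulator, dim=0)
    norm = norm * x.shape[1] / norm.sum()   # re-weight so the factors sum to the number of groups
    norm = norm.view(1, -1)
    while len(x.shape) > len(norm.shape):   # broadcast over the trailing spatial dims
        norm = norm.unsqueeze(-1)
    x = x / norm

    # Re-normalize so each position's group weights sum to 1 again, as on input.
    x = x / x.sum(dim=1, keepdim=True)
    print(x.sum(dim=1).mean())              # ~1.0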