forked from mrq/DL-Art-School
Fix switch norm average
parent 57e1a6a0f2
commit 9a6991e461
@@ -149,15 +149,16 @@ class SwitchNorm(nn.Module):
         # Compute the norm factor.
         if self.accumulator_filled > 0:
             norm = torch.mean(self.accumulator, dim=0)
+            norm = norm * x.shape[1] / norm.sum()  # The resulting norm should sum up to the total breadth: we are just re-weighting here.
         else:
             norm = torch.ones(self.group_size, device=self.accumulator.device)
         norm = norm.view(1,-1)
         while len(x.shape) > len(norm.shape):
             norm = norm.unsqueeze(-1)
         x = x / norm
 
         # Need to re-normalize x so that the groups dimension sum to 1, just like when it was fed in.
-        return x
+        return x / x.sum(dim=1, keepdim=True)
 
 
 class HardRoutingGate(nn.Module):
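For context, a minimal standalone sketch of what the changed lines do. The shapes, the seed, and the accumulator contents below are made-up stand-ins for SwitchNorm's state (self.accumulator), not the repository's actual values:

# Hypothetical stand-ins: 4 switch groups, a full accumulator of past per-group
# norms, and an input whose groups dimension (dim 1) sums to 1, as SwitchNorm expects.
import torch

torch.manual_seed(0)
group_size = 4
accumulator = torch.rand(128, group_size)                     # stand-in for self.accumulator
x = torch.softmax(torch.randn(2, group_size, 8, 8), dim=1)    # groups dim sums to 1

# Change 1: re-weight the accumulated per-group norm so it sums to the number
# of groups (x.shape[1]); an "average" group keeps a weight of 1.0.
norm = torch.mean(accumulator, dim=0)
norm = norm * x.shape[1] / norm.sum()

# Broadcast the per-group weights over batch and spatial dims, then divide.
norm = norm.view(1, -1)
while len(x.shape) > len(norm.shape):
    norm = norm.unsqueeze(-1)
x = x / norm

# Change 2: re-normalize so the groups dimension sums to 1 again, as on input.
x = x / x.sum(dim=1, keepdim=True)
print(x.sum(dim=1))  # all ones, up to float error

This mirrors the comments in the diff: the first division re-weights groups by their accumulated activity, and the final division restores the property that the attention over groups sums to 1.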