Fix dropout norm

James Betker 2021-06-07 16:13:23 -06:00
parent 438217094c
commit 108c5d829c


@@ -166,7 +166,8 @@ class DropoutNorm(SwitchNorm):
         # Ensure that there is always at least one switch left un-dropped out
         fix_blank = (drop.sum(dim=1, keepdim=True) == 0).repeat(1, br)
         drop = drop.logical_or(fix_blank)
-        x = drop * x + ((~drop) * x * self.eps)
+        x_dropped = drop * x + ~drop * -1e20
+        x = x_dropped
         return x
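
The replaced line swaps the old eps-scaling for a large negative fill, so a subsequent softmax drives dropped-out switches to exactly zero probability. A minimal standalone sketch of the mechanic (the shapes and the drop mask here are hypothetical; in the real module, drop comes from DropoutNorm's accumulator, with True marking switches that are kept):

import torch

# Hypothetical inputs: batch of 2, breadth of 4 switches.
x = torch.randn(2, 4)
drop = torch.tensor([[True, False, True, True],
                     [True, True, False, False]])

# Kept positions retain their logits; dropped positions are pushed to
# -1e20, which underflows to zero probability under a later softmax.
x_dropped = drop * x + ~drop * -1e20
soft = torch.nn.functional.softmax(x_dropped, dim=1)
print(soft)  # dropped columns come out as exactly 0.0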
@@ -177,7 +178,7 @@ class HardRoutingGate(nn.Module):
         self.norm = DropoutNorm(breadth, dropout_rate, accumulator_size=128)

     def forward(self, x):
-        soft = self.norm(nn.functional.softmax(x, dim=1))
+        soft = nn.functional.softmax(self.norm(x), dim=1)
         return RouteTop1.apply(soft)
         return soft
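
The reorder matters because the old code softmaxed first and then let DropoutNorm attenuate dropped entries, leaving the gate weights summing to less than one whenever anything was dropped; masking the raw logits first lets the softmax renormalize over the surviving switches. A small comparison sketch (hypothetical logits and mask, not the repo's classes; the old path is approximated here by zeroing rather than eps-scaling):

import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 1.0, 0.5, 0.1]])
mask = torch.tensor([[True, True, False, True]])  # third switch dropped

# Old ordering: softmax, then mask -- the weights no longer sum to 1.
old = mask * F.softmax(logits, dim=1)

# New ordering: mask logits to -1e20, then softmax -- the dropped switch
# gets zero weight and the survivors renormalize to a valid distribution.
new = F.softmax(mask * logits + ~mask * -1e20, dim=1)

print(old.sum().item(), new.sum().item())  # ~0.87 vs 1.0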