Fix dropout norm

parent 438217094c
commit 108c5d829c
@@ -166,7 +166,8 @@ class DropoutNorm(SwitchNorm):
         # Ensure that there is always at least one switch left un-dropped out
         fix_blank = (drop.sum(dim=1, keepdim=True) == 0).repeat(1, br)
         drop = drop.logical_or(fix_blank)
-        x = drop * x + ((~drop) * x * self.eps)
+        x_dropped = drop * x + ~drop * -1e20
+        x = x_dropped
 
         return x
 
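The point of this hunk: the old line kept dropped switches alive at a tiny magnitude (x * self.eps), which still receives nonzero probability from a downstream softmax. Filling dropped positions with a large negative constant (-1e20) instead makes the softmax assign them effectively zero mass. A minimal runnable sketch of the masking step, assuming (as the hunk suggests) that drop is a boolean mask with True marking switches that survive; the function and variable names here are ours, not the repo's:

import torch

def mask_dropped_logits(x: torch.Tensor, keep: torch.Tensor) -> torch.Tensor:
    # As in the commit: if a row dropped every switch, un-drop the whole row
    # so the softmax below always has at least one finite logit.
    fix_blank = (keep.sum(dim=1, keepdim=True) == 0).expand_as(keep)
    keep = keep.logical_or(fix_blank)
    # Kept positions pass through unchanged; dropped positions get a huge
    # negative logit so a later softmax drives their probability to ~0.
    return keep * x + ~keep * -1e20

x = torch.randn(4, 8)            # (batch, breadth) routing logits
keep = torch.rand(4, 8) > 0.5    # random per-switch dropout decisions
probs = torch.softmax(mask_dropped_logits(x, keep), dim=1)
# probs is ~0 at dropped switches and renormalized over the kept ones.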
@@ -177,7 +178,7 @@ class HardRoutingGate(nn.Module):
         self.norm = DropoutNorm(breadth, dropout_rate, accumulator_size=128)
 
     def forward(self, x):
-        soft = self.norm(nn.functional.softmax(x, dim=1))
+        soft = nn.functional.softmax(self.norm(x), dim=1)
         return RouteTop1.apply(soft)
         return soft
 
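The second hunk is the complementary half of the fix: now that DropoutNorm emits -1e20 at dropped positions, it must run on the raw logits before the softmax rather than on its output, and the softmax then renormalizes the remaining probability mass over the surviving switches. (The trailing return soft left in the source is unreachable.) The gate finally hardens the distribution with RouteTop1, whose implementation is not part of this diff; below is a common straight-through formulation of such an op, offered as an assumption about its behavior, not as the repo's actual code:

import torch

class RouteTop1Sketch(torch.autograd.Function):
    # Hypothetical stand-in for the repo's RouteTop1.
    @staticmethod
    def forward(ctx, soft):
        # Forward pass: hard one-hot pick of the most probable switch.
        hard = torch.zeros_like(soft)
        hard.scatter_(1, soft.argmax(dim=1, keepdim=True), 1.0)
        return hard

    @staticmethod
    def backward(ctx, grad_output):
        # Backward pass: straight-through, i.e. gradients reach the soft
        # probabilities as if the forward pass were the identity.
        return grad_output

soft = torch.softmax(torch.randn(4, 8), dim=1)
route = RouteTop1Sketch.apply(soft)   # one-hot per row in the forward pass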