This commit is contained in:
James Betker 2021-06-07 11:58:36 -06:00
parent c456a60466
commit f0d4eb9182

View File

@ -166,7 +166,7 @@ class DropoutNorm(SwitchNorm):
# Ensure that there is always at least one switch left un-dropped out
fix_blank = (drop.sum(dim=1, keepdim=True) == 0).repeat(1, br)
drop = drop.logical_or(fix_blank)
x = drop * x + ((not drop) * x * self.eps)
x = drop * x + ((~drop) * x * self.eps)
return x