Attempt to fix nan

James Betker 2021-06-07 11:43:42 -06:00
parent eda796985b
commit 1c574c5bd1


```diff
@@ -158,12 +158,13 @@ class DropoutNorm(SwitchNorm):
         # Compute the dropout probabilities. This module is a no-op before the accumulator is initialized.
         if self.accumulator_filled > 0:
-            probs = torch.mean(self.accumulator, dim=0) * self.dropout_rate
-            bs, br = x.shape[:2]
-            drop = torch.rand((bs, br), device=x.device) > probs.unsqueeze(0)
-            # Ensure that there is always at least one switch left un-dropped out
-            fix_blank = (drop.sum(dim=1, keepdim=True) == 0).repeat(1, br)
-            drop = drop.logical_or(fix_blank)
+            with torch.no_grad():
+                probs = torch.mean(self.accumulator, dim=0) * self.dropout_rate
+                bs, br = x.shape[:2]
+                drop = torch.rand((bs, br), device=x.device) > probs.unsqueeze(0)
+                # Ensure that there is always at least one switch left un-dropped out
+                fix_blank = (drop.sum(dim=1, keepdim=True) == 0).repeat(1, br)
+                drop = drop.logical_or(fix_blank)
             x = drop * x
         return x
```
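The change in this hunk is to build the dropout mask under `torch.no_grad()`, so the accumulator statistics and the random mask are treated as constants and no gradient flows through them; only the multiplication by `x` remains in the graph. A minimal sketch of the same pattern as a standalone function (the function name and signature are illustrative, not the repo's):

```python
import torch

# Minimal sketch (not the repo's module): compute the per-switch dropout mask
# with gradients disabled, mirroring the lines changed in the hunk above.
def switch_dropout_mask(accumulator: torch.Tensor, x: torch.Tensor, dropout_rate: float) -> torch.Tensor:
    with torch.no_grad():
        # Per-switch drop probability from the accumulated usage statistics.
        probs = torch.mean(accumulator, dim=0) * dropout_rate
        bs, br = x.shape[:2]
        drop = torch.rand((bs, br), device=x.device) > probs.unsqueeze(0)
        # Keep at least one switch alive per sample.
        fix_blank = (drop.sum(dim=1, keepdim=True) == 0).repeat(1, br)
        drop = drop.logical_or(fix_blank)
    return drop

# The mask itself carries no gradient; multiplying by x outside no_grad keeps
# x differentiable, as in the diff:  x = switch_dropout_mask(acc, x, 0.8) * x
```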
```diff
@@ -172,7 +173,7 @@ class DropoutNorm(SwitchNorm):
 class HardRoutingGate(nn.Module):
     def __init__(self, breadth, dropout_rate=.8):
         super().__init__()
-        self.norm = DropoutNorm(breadth, dropout_rate, accumulator_size=2)
+        self.norm = DropoutNorm(breadth, dropout_rate, accumulator_size=128)
 
     def forward(self, x):
         soft = self.norm(nn.functional.softmax(x, dim=1))
```
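The second hunk only enlarges the accumulator from 2 to 128 entries, so the usage statistics behind `probs` are averaged over many more batches before they gate dropout, which should make them far less jumpy. A toy illustration of the effect of the window size (synthetic numbers, not repo code):

```python
import torch

# Toy comparison (synthetic data): spread of a rolling mean over 2 vs. 128
# stored usage vectors. The wider window gives much steadier drop probabilities.
torch.manual_seed(0)
usage = torch.rand(1000, 8)                     # fake per-batch usage for 8 switches
mean_2 = usage.unfold(0, 2, 1).mean(dim=2)      # what accumulator_size=2 would average
mean_128 = usage.unfold(0, 128, 1).mean(dim=2)  # what accumulator_size=128 would average
print(mean_2.std(dim=0).mean().item())          # noisy estimate
print(mean_128.std(dim=0).mean().item())        # much lower spread
```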