forked from mrq/DL-Art-School

Attempt to fix nan

commit 1c574c5bd1, parent eda796985b
@@ -158,12 +158,13 @@ class DropoutNorm(SwitchNorm):
         # Compute the dropout probabilities. This module is a no-op before the accumulator is initialized.
         if self.accumulator_filled > 0:
-            probs = torch.mean(self.accumulator, dim=0) * self.dropout_rate
-            bs, br = x.shape[:2]
-            drop = torch.rand((bs, br), device=x.device) > probs.unsqueeze(0)
-            # Ensure that there is always at least one switch left un-dropped out
-            fix_blank = (drop.sum(dim=1, keepdim=True) == 0).repeat(1, br)
-            drop = drop.logical_or(fix_blank)
+            with torch.no_grad():
+                probs = torch.mean(self.accumulator, dim=0) * self.dropout_rate
+                bs, br = x.shape[:2]
+                drop = torch.rand((bs, br), device=x.device) > probs.unsqueeze(0)
+                # Ensure that there is always at least one switch left un-dropped out
+                fix_blank = (drop.sum(dim=1, keepdim=True) == 0).repeat(1, br)
+                drop = drop.logical_or(fix_blank)
             x = drop * x
 
         return x
 
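For readers skimming the hunk above: the whole mask computation now runs under torch.no_grad(), while the final x = drop * x stays outside it, so gradients still flow through the surviving entries of x. A minimal, self-contained sketch of the module after this change follows. DropoutNormSketch is a hypothetical stand-in (the real class derives from SwitchNorm), and the buffer shapes and registration are assumptions; only the forward body mirrors the diff.

import torch
import torch.nn as nn

class DropoutNormSketch(nn.Module):
    # Hypothetical stand-in for DropoutNorm: randomly drops whole switch lanes
    # based on accumulated usage statistics. Buffer names mirror the diff; the
    # real class derives from SwitchNorm and fills the accumulator elsewhere.
    def __init__(self, breadth, dropout_rate, accumulator_size=128):
        super().__init__()
        self.dropout_rate = dropout_rate
        # Rolling record of past switch activations, one row per accumulated batch (assumed shape).
        self.register_buffer('accumulator', torch.zeros(accumulator_size, breadth))
        self.register_buffer('accumulator_filled', torch.zeros(1, dtype=torch.long))

    def forward(self, x):
        # Compute the dropout probabilities. This module is a no-op before the accumulator is initialized.
        if self.accumulator_filled > 0:
            # The mask is pure bookkeeping: building it under no_grad() keeps the
            # random comparison ops out of the autograd graph.
            with torch.no_grad():
                probs = torch.mean(self.accumulator, dim=0) * self.dropout_rate
                bs, br = x.shape[:2]
                drop = torch.rand((bs, br), device=x.device) > probs.unsqueeze(0)
                # Ensure that there is always at least one switch left un-dropped out
                fix_blank = (drop.sum(dim=1, keepdim=True) == 0).repeat(1, br)
                drop = drop.logical_or(fix_blank)
            # Outside no_grad: gradients still flow through the kept entries of x.
            x = drop * x
        return x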
@@ -172,7 +173,7 @@ class DropoutNorm(SwitchNorm):
 class HardRoutingGate(nn.Module):
     def __init__(self, breadth, dropout_rate=.8):
         super().__init__()
-        self.norm = DropoutNorm(breadth, dropout_rate, accumulator_size=2)
+        self.norm = DropoutNorm(breadth, dropout_rate, accumulator_size=128)
 
     def forward(self, x):
         soft = self.norm(nn.functional.softmax(x, dim=1))
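The second hunk raises accumulator_size from 2 to 128, so probs is averaged over up to 128 recorded batches of switch usage rather than two, which should make the per-switch dropout probabilities much less noisy. A quick smoke test of the sketch above, with hypothetical shapes and values (the real gate feeds in softmaxed router logits of shape (batch, breadth)):

torch.manual_seed(0)
gate = DropoutNormSketch(breadth=8, dropout_rate=.8, accumulator_size=128)
logits = torch.randn(4, 8, requires_grad=True)
soft = nn.functional.softmax(logits, dim=1)

# Before the accumulator is filled the module is a pass-through.
assert torch.equal(gate(soft), soft)

# Pretend the accumulator has been populated, then check that every row keeps
# at least one live switch and that gradients still reach the logits.
gate.accumulator.uniform_(0, 1)
gate.accumulator_filled.fill_(1)
out = gate(soft)
assert bool((out > 0).any(dim=1).all())
out.sum().backward()
print(logits.grad.abs().sum())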