Another go at fixing NaN
This commit is contained in:
parent
1c574c5bd1
commit
c456a60466
|
@ -112,7 +112,7 @@ class ResNetTail(nn.Module):
|
||||||
|
|
||||||
|
|
||||||
class DropoutNorm(SwitchNorm):
|
class DropoutNorm(SwitchNorm):
|
||||||
def __init__(self, group_size, dropout_rate, accumulator_size=256):
|
def __init__(self, group_size, dropout_rate, accumulator_size=256, eps=1e-6):
|
||||||
super().__init__(group_size, accumulator_size)
|
super().__init__(group_size, accumulator_size)
|
||||||
self.accumulator_desired_size = accumulator_size
|
self.accumulator_desired_size = accumulator_size
|
||||||
self.group_size = group_size
|
self.group_size = group_size
|
||||||
|
@ -120,6 +120,7 @@ class DropoutNorm(SwitchNorm):
|
||||||
self.register_buffer("accumulator_index", torch.zeros(1, dtype=torch.long, device='cpu'))
|
self.register_buffer("accumulator_index", torch.zeros(1, dtype=torch.long, device='cpu'))
|
||||||
self.register_buffer("accumulator_filled", torch.zeros(1, dtype=torch.long, device='cpu'))
|
self.register_buffer("accumulator_filled", torch.zeros(1, dtype=torch.long, device='cpu'))
|
||||||
self.register_buffer("accumulator", torch.zeros(accumulator_size, group_size))
|
self.register_buffer("accumulator", torch.zeros(accumulator_size, group_size))
|
||||||
|
self.eps = eps
|
||||||
|
|
||||||
def add_norm_to_buffer(self, x):
|
def add_norm_to_buffer(self, x):
|
||||||
flatten_dims = [0] + [k+2 for k in range(len(x.shape)-2)]
|
flatten_dims = [0] + [k+2 for k in range(len(x.shape)-2)]
|
||||||
|
@ -165,7 +166,7 @@ class DropoutNorm(SwitchNorm):
|
||||||
# Ensure that there is always at least one switch left un-dropped out
|
# Ensure that there is always at least one switch left un-dropped out
|
||||||
fix_blank = (drop.sum(dim=1, keepdim=True) == 0).repeat(1, br)
|
fix_blank = (drop.sum(dim=1, keepdim=True) == 0).repeat(1, br)
|
||||||
drop = drop.logical_or(fix_blank)
|
drop = drop.logical_or(fix_blank)
|
||||||
x = drop * x
|
x = drop * x + ((not drop) * x * self.eps)
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user