Add fade in for hard switch

James Betker 2021-06-07 18:15:09 -06:00
parent 108c5d829c
commit 9b5f4abb91


@@ -173,14 +173,19 @@ class DropoutNorm(SwitchNorm):
 class HardRoutingGate(nn.Module):
-    def __init__(self, breadth, dropout_rate=.8):
+    def __init__(self, breadth, fade_steps=10000, dropout_rate=.8):
         super().__init__()
         self.norm = DropoutNorm(breadth, dropout_rate, accumulator_size=128)
+        self.fade_steps = fade_steps
+        self.register_buffer("last_step", torch.zeros(1, dtype=torch.long, device='cpu'))

     def forward(self, x):
+        if self.last_step < self.fade_steps:
+            x = torch.randn_like(x) * (self.fade_steps - self.last_step) / self.fade_steps + \
+                x * self.last_step / self.fade_steps
+            self.last_step = self.last_step + 1
         soft = nn.functional.softmax(self.norm(x), dim=1)
         return RouteTop1.apply(soft)
-        return soft

 class ResNet(nn.Module):
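
In plain terms, the change makes the gate warm up: for the first fade_steps forward passes its input is a linear blend of Gaussian noise and the real signal, with the noise weight annealing from 1 down to 0, after which the gate sees the raw input. Below is a minimal standalone sketch of just that fade-in behavior, not the repository's code: the class name FadeInGateSketch is hypothetical, and DropoutNorm/RouteTop1 are replaced by a plain softmax stand-in.

    import torch
    import torch.nn as nn

    class FadeInGateSketch(nn.Module):
        """Hypothetical stand-in illustrating the linear noise fade-in."""
        def __init__(self, fade_steps=10000):
            super().__init__()
            self.fade_steps = fade_steps
            # Buffer (not a plain int) so the fade position is saved/restored with the state dict.
            self.register_buffer("last_step", torch.zeros(1, dtype=torch.long))

        def forward(self, x):
            if self.last_step < self.fade_steps:
                # Blend from pure noise (step 0) toward the real input (step fade_steps).
                noise_w = (self.fade_steps - self.last_step) / self.fade_steps
                x = torch.randn_like(x) * noise_w + x * (1 - noise_w)
                self.last_step = self.last_step + 1
            # Stand-in for the real DropoutNorm -> softmax -> RouteTop1 pipeline.
            return torch.softmax(x, dim=1)

    gate = FadeInGateSketch(fade_steps=4)
    for _ in range(6):
        out = gate(torch.randn(2, 8))  # first 4 calls blend in noise, later calls pass x through

Keeping last_step in a registered buffer rather than a Python attribute is presumably why the commit uses register_buffer: the fade progress then survives checkpoint save/load instead of restarting from zero.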