Add fade in for hard switch
parent 108c5d829c
commit 9b5f4abb91
@@ -173,14 +173,19 @@ class DropoutNorm(SwitchNorm):
 
 
 class HardRoutingGate(nn.Module):
-    def __init__(self, breadth, dropout_rate=.8):
+    def __init__(self, breadth, fade_steps=10000, dropout_rate=.8):
         super().__init__()
         self.norm = DropoutNorm(breadth, dropout_rate, accumulator_size=128)
+        self.fade_steps = fade_steps
+        self.register_buffer("last_step", torch.zeros(1, dtype=torch.long, device='cpu'))
 
     def forward(self, x):
+        if self.last_step < self.fade_steps:
+            x = torch.randn_like(x) * (self.fade_steps - self.last_step) / self.fade_steps + \
+                x * self.last_step / self.fade_steps
+            self.last_step = self.last_step + 1
         soft = nn.functional.softmax(self.norm(x), dim=1)
         return RouteTop1.apply(soft)
         return soft
 
 
 class ResNet(nn.Module):
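For context, here is a minimal sketch of the fade-in schedule this commit introduces, written as a standalone function (the name `faded_input` and the free-function form are illustrative, not repository code): for the first `fade_steps` forward passes the gate input is a linear blend of Gaussian noise and the real activations, starting at pure noise and ending at the unmodified input. Because `last_step` is registered as a buffer, the counter is saved in the model's `state_dict`, so a resumed run continues the fade rather than restarting it.

```python
import torch

# Illustrative sketch (not repository code): the linear fade-in schedule
# applied in HardRoutingGate.forward for the first `fade_steps` steps.
def faded_input(x: torch.Tensor, last_step: int, fade_steps: int) -> torch.Tensor:
    if last_step >= fade_steps:
        return x                              # fade finished: pass the input through untouched
    noise_weight = (fade_steps - last_step) / fade_steps
    # Pure noise at step 0, pure input once last_step reaches fade_steps.
    return torch.randn_like(x) * noise_weight + x * (1.0 - noise_weight)

x = torch.ones(2, 4)
print(faded_input(x, 0, 10000))       # all noise: the gate starts out routing essentially at random
print(faded_input(x, 5000, 10000))    # 50/50 blend of noise and the real activations
print(faded_input(x, 10000, 10000))   # identical to x: hard routing now sees the real data
```

The practical effect is that the hard top-1 routing decision (`RouteTop1.apply` on the softmax output) is input-independent early in training and only gradually becomes driven by the activations, presumably so the gate does not commit to a single branch before the rest of the network has learned useful features.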