forked from mrq/DL-Art-School
Add fade in for hard switch
This commit is contained in:
parent 108c5d829c
commit 9b5f4abb91
@@ -173,14 +173,19 @@ class DropoutNorm(SwitchNorm):
 
 
 class HardRoutingGate(nn.Module):
-    def __init__(self, breadth, dropout_rate=.8):
+    def __init__(self, breadth, fade_steps=10000, dropout_rate=.8):
         super().__init__()
         self.norm = DropoutNorm(breadth, dropout_rate, accumulator_size=128)
+        self.fade_steps = fade_steps
+        self.register_buffer("last_step", torch.zeros(1, dtype=torch.long, device='cpu'))
 
     def forward(self, x):
+        if self.last_step < self.fade_steps:
+            x = torch.randn_like(x) * (self.fade_steps - self.last_step) / self.fade_steps + \
+                x * self.last_step / self.fade_steps
+            self.last_step = self.last_step + 1
         soft = nn.functional.softmax(self.norm(x), dim=1)
         return RouteTop1.apply(soft)
         return soft
 
 
 class ResNet(nn.Module):
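For reference, a minimal sketch of the fade-in idea this commit introduces: for the first fade_steps forward passes, the gate's input starts as pure Gaussian noise and linearly shifts toward the real activations, so the hard router initially picks branches at random and only gradually commits to the learned routing signal. The SimpleHardGate class below is a hypothetical stand-in that assumes plain softmax + argmax routing in place of the repo's DropoutNorm and RouteTop1; it is not the actual DL-Art-School implementation.

import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleHardGate(nn.Module):
    """Illustrative gate: linearly fades from random noise to the real input
    over `fade_steps` forward passes, then hard-routes via argmax.
    (Stand-in sketch; DropoutNorm and RouteTop1 from the repo are omitted.)"""
    def __init__(self, breadth, fade_steps=10000):
        super().__init__()
        self.fade_steps = fade_steps
        # Buffer so the step counter is saved/restored with the state dict.
        self.register_buffer("last_step", torch.zeros(1, dtype=torch.long))

    def forward(self, x):
        if self.last_step < self.fade_steps:
            # Blend weight: 0 at step 0 (all noise), 1 once last_step == fade_steps (all signal).
            w = self.last_step.float() / self.fade_steps
            x = torch.randn_like(x) * (1 - w) + x * w
            self.last_step = self.last_step + 1
        soft = F.softmax(x, dim=1)
        # Hard routing: one-hot of the argmax (no straight-through gradient in this sketch).
        return F.one_hot(soft.argmax(dim=1), num_classes=soft.shape[1]).to(soft.dtype)

gate = SimpleHardGate(breadth=4, fade_steps=100)
out = gate(torch.randn(2, 4))   # -> one-hot routing decisions, shape (2, 4)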