diff --git a/codes/models/classifiers/cifar_resnet_branched.py b/codes/models/classifiers/cifar_resnet_branched.py index 3e9b733c..bc114bdd 100644 --- a/codes/models/classifiers/cifar_resnet_branched.py +++ b/codes/models/classifiers/cifar_resnet_branched.py @@ -173,14 +173,19 @@ class DropoutNorm(SwitchNorm): class HardRoutingGate(nn.Module): - def __init__(self, breadth, dropout_rate=.8): + def __init__(self, breadth, fade_steps=10000, dropout_rate=.8): super().__init__() self.norm = DropoutNorm(breadth, dropout_rate, accumulator_size=128) + self.fade_steps = fade_steps + self.register_buffer("last_step", torch.zeros(1, dtype=torch.long, device='cpu')) def forward(self, x): + if self.last_step < self.fade_steps: + x = torch.randn_like(x) * (self.fade_steps - self.last_step) / self.fade_steps + \ + x * self.last_step / self.fade_steps + self.last_step = self.last_step + 1 soft = nn.functional.softmax(self.norm(x), dim=1) return RouteTop1.apply(soft) - return soft class ResNet(nn.Module):