diff --git a/codes/models/classifiers/cifar_resnet_branched.py b/codes/models/classifiers/cifar_resnet_branched.py index 4d1e7894..3e9b733c 100644 --- a/codes/models/classifiers/cifar_resnet_branched.py +++ b/codes/models/classifiers/cifar_resnet_branched.py @@ -166,7 +166,8 @@ class DropoutNorm(SwitchNorm): # Ensure that there is always at least one switch left un-dropped out fix_blank = (drop.sum(dim=1, keepdim=True) == 0).repeat(1, br) drop = drop.logical_or(fix_blank) - x = drop * x + ((~drop) * x * self.eps) + x_dropped = drop * x + ~drop * -1e20 + x = x_dropped return x @@ -177,7 +178,7 @@ class HardRoutingGate(nn.Module): self.norm = DropoutNorm(breadth, dropout_rate, accumulator_size=128) def forward(self, x): - soft = self.norm(nn.functional.softmax(x, dim=1)) + soft = nn.functional.softmax(self.norm(x), dim=1) return RouteTop1.apply(soft) return soft