diff --git a/codes/models/classifiers/cifar_resnet_branched.py b/codes/models/classifiers/cifar_resnet_branched.py
index 57b03f28..a5a6c750 100644
--- a/codes/models/classifiers/cifar_resnet_branched.py
+++ b/codes/models/classifiers/cifar_resnet_branched.py
@@ -158,12 +158,13 @@ class DropoutNorm(SwitchNorm):
 
         # Compute the dropout probabilities. This module is a no-op before the accumulator is initialized.
         if self.accumulator_filled > 0:
-            probs = torch.mean(self.accumulator, dim=0) * self.dropout_rate
-            bs, br = x.shape[:2]
-            drop = torch.rand((bs, br), device=x.device) > probs.unsqueeze(0)
-            # Ensure that there is always at least one switch left un-dropped out
-            fix_blank = (drop.sum(dim=1, keepdim=True) == 0).repeat(1, br)
-            drop = drop.logical_or(fix_blank)
+            with torch.no_grad():
+                probs = torch.mean(self.accumulator, dim=0) * self.dropout_rate
+                bs, br = x.shape[:2]
+                drop = torch.rand((bs, br), device=x.device) > probs.unsqueeze(0)
+                # Ensure that there is always at least one switch left un-dropped out
+                fix_blank = (drop.sum(dim=1, keepdim=True) == 0).repeat(1, br)
+                drop = drop.logical_or(fix_blank)
             x = drop * x
 
         return x
@@ -172,7 +173,7 @@ class DropoutNorm(SwitchNorm):
 class HardRoutingGate(nn.Module):
     def __init__(self, breadth, dropout_rate=.8):
         super().__init__()
-        self.norm = DropoutNorm(breadth, dropout_rate, accumulator_size=2)
+        self.norm = DropoutNorm(breadth, dropout_rate, accumulator_size=128)
 
     def forward(self, x):
         soft = self.norm(nn.functional.softmax(x, dim=1))
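
For context, below is a minimal standalone sketch of the mask computation the first hunk wraps in torch.no_grad(). It is illustrative only: the real DropoutNorm keeps `accumulator` and `accumulator_filled` as module state, and the helper name `branch_dropout_mask` plus the example shapes are assumptions made up for this sketch. The mask is pure bookkeeping over random draws and running statistics, so excluding it from the autograd graph saves memory without changing gradients flowing through `x`.

    # Illustrative sketch, not the repository's module.
    import torch

    def branch_dropout_mask(accumulator: torch.Tensor, x: torch.Tensor,
                            dropout_rate: float = 0.8) -> torch.Tensor:
        """Builds a per-sample binary mask over the branch dimension of x ([bs, br])."""
        # Computed under no_grad: the mask is random bookkeeping, not something to backprop through.
        with torch.no_grad():
            probs = torch.mean(accumulator, dim=0) * dropout_rate   # [br] drop probability per branch
            bs, br = x.shape[:2]
            drop = torch.rand((bs, br), device=x.device) > probs.unsqueeze(0)
            # Guarantee at least one branch survives for every sample in the batch.
            fix_blank = (drop.sum(dim=1, keepdim=True) == 0).repeat(1, br)
            drop = drop.logical_or(fix_blank)
        return drop

    # Usage (hypothetical shapes): breadth-4 routing weights for a batch of 8.
    acc = torch.rand(128, 4)                        # stands in for the running usage accumulator
    x = torch.softmax(torch.randn(8, 4), dim=1)     # [batch, breadth] routing weights
    x = branch_dropout_mask(acc, x) * x             # dropped branches contribute zero

The second hunk only enlarges the accumulator (`accumulator_size=2` to `128`), so the drop probabilities are averaged over more recent batches before being applied.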