diff --git a/codes/models/classifiers/cifar_resnet_branched.py b/codes/models/classifiers/cifar_resnet_branched.py index e37b2586..f66abb7b 100644 --- a/codes/models/classifiers/cifar_resnet_branched.py +++ b/codes/models/classifiers/cifar_resnet_branched.py @@ -197,7 +197,7 @@ class ResNet(nn.Module): self.tails = nn.ModuleList([ResNetTail(block, num_block, 256) for _ in range(num_tails)]) self.selector = ResNetTail(block, num_block, num_tails) self.selector_gate = nn.Linear(256, 1) - self.gate = HardRoutingGate(num_tails) + self.gate = HardRoutingGate(num_tails, dropout_rate=2) self.final_linear = nn.Linear(256, num_classes) def _make_layer(self, block, out_channels, num_blocks, stride): diff --git a/codes/scripts/cifar100_untangle.py b/codes/scripts/cifar100_untangle.py index 6395417b..1f8f67f3 100644 --- a/codes/scripts/cifar100_untangle.py +++ b/codes/scripts/cifar100_untangle.py @@ -21,7 +21,7 @@ if __name__ == '__main__': set = TorchDataset(dopt) loader = DataLoader(set, num_workers=0, batch_size=32) model = ResNet(BasicBlock, [2, 2, 2, 2]) - model.load_state_dict(torch.load('C:\\Users\\jbetk\\Downloads\\cifar_hardsw_85000.pth')) + model.load_state_dict(torch.load('C:\\Users\\jbetk\\Downloads\\cifar_hardw_10000.pth')) model.eval() bins = [[] for _ in range(8)]