Amplify dropout rate

This commit is contained in:
James Betker 2021-06-07 15:20:53 -06:00
parent f0d4eb9182
commit 44b09e5f20
2 changed files with 2 additions and 2 deletions

View File

@ -197,7 +197,7 @@ class ResNet(nn.Module):
self.tails = nn.ModuleList([ResNetTail(block, num_block, 256) for _ in range(num_tails)]) self.tails = nn.ModuleList([ResNetTail(block, num_block, 256) for _ in range(num_tails)])
self.selector = ResNetTail(block, num_block, num_tails) self.selector = ResNetTail(block, num_block, num_tails)
self.selector_gate = nn.Linear(256, 1) self.selector_gate = nn.Linear(256, 1)
self.gate = HardRoutingGate(num_tails) self.gate = HardRoutingGate(num_tails, dropout_rate=2)
self.final_linear = nn.Linear(256, num_classes) self.final_linear = nn.Linear(256, num_classes)
def _make_layer(self, block, out_channels, num_blocks, stride): def _make_layer(self, block, out_channels, num_blocks, stride):

View File

@ -21,7 +21,7 @@ if __name__ == '__main__':
set = TorchDataset(dopt) set = TorchDataset(dopt)
loader = DataLoader(set, num_workers=0, batch_size=32) loader = DataLoader(set, num_workers=0, batch_size=32)
model = ResNet(BasicBlock, [2, 2, 2, 2]) model = ResNet(BasicBlock, [2, 2, 2, 2])
model.load_state_dict(torch.load('C:\\Users\\jbetk\\Downloads\\cifar_hardsw_85000.pth')) model.load_state_dict(torch.load('C:\\Users\\jbetk\\Downloads\\cifar_hardw_10000.pth'))
model.eval() model.eval()
bins = [[] for _ in range(8)] bins = [[] for _ in range(8)]