Amplify dropout rate
This commit is contained in:
parent
f0d4eb9182
commit
44b09e5f20
|
@ -197,7 +197,7 @@ class ResNet(nn.Module):
|
|||
self.tails = nn.ModuleList([ResNetTail(block, num_block, 256) for _ in range(num_tails)])
|
||||
self.selector = ResNetTail(block, num_block, num_tails)
|
||||
self.selector_gate = nn.Linear(256, 1)
|
||||
self.gate = HardRoutingGate(num_tails)
|
||||
self.gate = HardRoutingGate(num_tails, dropout_rate=2)
|
||||
self.final_linear = nn.Linear(256, num_classes)
|
||||
|
||||
def _make_layer(self, block, out_channels, num_blocks, stride):
|
||||
|
|
|
@ -21,7 +21,7 @@ if __name__ == '__main__':
|
|||
set = TorchDataset(dopt)
|
||||
loader = DataLoader(set, num_workers=0, batch_size=32)
|
||||
model = ResNet(BasicBlock, [2, 2, 2, 2])
|
||||
model.load_state_dict(torch.load('C:\\Users\\jbetk\\Downloads\\cifar_hardsw_85000.pth'))
|
||||
model.load_state_dict(torch.load('C:\\Users\\jbetk\\Downloads\\cifar_hardw_10000.pth'))
|
||||
model.eval()
|
||||
|
||||
bins = [[] for _ in range(8)]
|
||||
|
|
Loading…
Reference in New Issue
Block a user