From a0158ebc6915c78a7cf20fd01694dcf57846ee2e Mon Sep 17 00:00:00 2001
From: James Betker
Date: Sun, 6 Jun 2021 10:02:24 -0600
Subject: [PATCH] Simplify cifar resnet further for faster training

---
 .../classifiers/cifar_resnet_branched.py | 24 +++++++++----------
 1 file changed, 12 insertions(+), 12 deletions(-)

diff --git a/codes/models/classifiers/cifar_resnet_branched.py b/codes/models/classifiers/cifar_resnet_branched.py
index 293ed46f..f5a53719 100644
--- a/codes/models/classifiers/cifar_resnet_branched.py
+++ b/codes/models/classifiers/cifar_resnet_branched.py
@@ -85,11 +85,11 @@ class ResNetTail(nn.Module):
     def __init__(self, block, num_block, num_classes=100):
         super().__init__()
 
-        self.in_channels = 128
-        self.conv4_x = self._make_layer(block, 256, num_block[2], 2)
-        self.conv5_x = self._make_layer(block, 512, num_block[3], 2)
+        self.in_channels = 64
+        self.conv4_x = self._make_layer(block, 128, num_block[2], 2)
+        self.conv5_x = self._make_layer(block, 256, num_block[3], 2)
         self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
-        self.fc = nn.Linear(512 * block.expansion, num_classes)
+        self.fc = nn.Linear(256 * block.expansion, num_classes)
 
     def _make_layer(self, block, out_channels, num_blocks, stride):
         strides = [stride] + [1] * (num_blocks - 1)
@@ -111,19 +111,19 @@ class ResNetTail(nn.Module):
 
 class ResNet(nn.Module):
 
-    def __init__(self, block, num_block, num_classes=100, num_tails=20):
+    def __init__(self, block, num_block, num_classes=100, num_tails=8):
         super().__init__()
 
-        self.in_channels = 64
+        self.in_channels = 32
         self.conv1 = nn.Sequential(
-            nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False),
-            nn.BatchNorm2d(64),
+            nn.Conv2d(3, 32, kernel_size=3, padding=1, bias=False),
+            nn.BatchNorm2d(32),
             nn.ReLU(inplace=True))
-        self.conv2_x = self._make_layer(block, 64, num_block[0], 1)
-        self.conv3_x = self._make_layer(block, 128, num_block[1], 2)
+        self.conv2_x = self._make_layer(block, 32, num_block[0], 1)
+        self.conv3_x = self._make_layer(block, 64, num_block[1], 2)
         self.tails = nn.ModuleList([ResNetTail(block, num_block, 256) for _ in range(num_tails)])
         self.selector = ResNetTail(block, num_block, num_tails)
-        self.final_linear = nn.Linear(256, 100)
+        self.final_linear = nn.Linear(256, num_classes)
 
     def _make_layer(self, block, out_channels, num_blocks, stride):
         strides = [stride] + [1] * (num_blocks - 1)
@@ -181,5 +181,5 @@ def resnet152():
 
 
 if __name__ == '__main__':
     model = ResNet(BasicBlock, [2,2,2,2])
-    print(model(torch.randn(2,3,32,32), torch.LongTensor([4,19])).shape)
+    print(model(torch.randn(2,3,32,32), torch.LongTensor([4,7])).shape)
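
For reference, a minimal usage sketch of the slimmed-down model after this patch (an illustration, not part of the applied diff). It assumes the file is importable as models.classifiers.cifar_resnet_branched from the codes/ directory, and that the second forward argument is a per-sample tail index, as the __main__ block above suggests; with num_tails reduced from 20 to 8, those indices must now stay below 8.

import torch
from models.classifiers.cifar_resnet_branched import ResNet, BasicBlock  # assumed import path

# Instantiate with the new defaults: 32-channel stem, 8 tails.
model = ResNet(BasicBlock, [2, 2, 2, 2], num_classes=100, num_tails=8)

images = torch.randn(4, 3, 32, 32)        # CIFAR-sized input batch
tail_idx = torch.LongTensor([0, 3, 7, 2])  # one tail index per sample; must be < num_tails (now 8)
logits = model(images, tail_idx)
print(logits.shape)                        # presumably torch.Size([4, 100]), since final_linear maps 256 -> num_classes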