"""resnet in pytorch [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun. Deep Residual Learning for Image Recognition https://arxiv.org/abs/1512.03385v1 """ import torch import torch.nn as nn import torch.nn.functional as F from trainer.networks import register_model class BasicBlock(nn.Module): """Basic Block for resnet 18 and resnet 34 """ #BasicBlock and BottleNeck block #have different output size #we use class attribute expansion #to distinct expansion = 1 def __init__(self, in_channels, out_channels, stride=1): super().__init__() #residual function self.residual_function = nn.Sequential( nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False), nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True), nn.Conv2d(out_channels, out_channels * BasicBlock.expansion, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(out_channels * BasicBlock.expansion) ) #shortcut self.shortcut = nn.Sequential() #the shortcut output dimension is not the same with residual function #use 1*1 convolution to match the dimension if stride != 1 or in_channels != BasicBlock.expansion * out_channels: self.shortcut = nn.Sequential( nn.Conv2d(in_channels, out_channels * BasicBlock.expansion, kernel_size=1, stride=stride, bias=False), nn.BatchNorm2d(out_channels * BasicBlock.expansion) ) def forward(self, x): return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x)) class BottleNeck(nn.Module): """Residual block for resnet over 50 layers """ expansion = 4 def __init__(self, in_channels, out_channels, stride=1): super().__init__() self.residual_function = nn.Sequential( nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False), nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True), nn.Conv2d(out_channels, out_channels, stride=stride, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True), nn.Conv2d(out_channels, out_channels * BottleNeck.expansion, kernel_size=1, bias=False), nn.BatchNorm2d(out_channels * BottleNeck.expansion), ) self.shortcut = nn.Sequential() if stride != 1 or in_channels != out_channels * BottleNeck.expansion: self.shortcut = nn.Sequential( nn.Conv2d(in_channels, out_channels * BottleNeck.expansion, stride=stride, kernel_size=1, bias=False), nn.BatchNorm2d(out_channels * BottleNeck.expansion) ) def forward(self, x): return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x)) class ResNet(nn.Module): def __init__(self, block, num_block, num_classes=100): super().__init__() self.in_channels = 32 self.conv1 = nn.Sequential( nn.Conv2d(3, 32, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(32), nn.ReLU(inplace=True)) #we use a different inputsize than the original paper #so conv2_x's stride is 1 self.conv2_x = self._make_layer(block, 32, num_block[0], 1) self.conv3_x = self._make_layer(block, 64, num_block[1], 2) self.conv4_x = self._make_layer(block, 128, num_block[2], 2) self.conv5_x = self._make_layer(block, 256, num_block[3], 2) self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) self.fc = nn.Linear(256 * block.expansion, num_classes) def _make_layer(self, block, out_channels, num_blocks, stride): """make resnet layers(by layer i didnt mean this 'layer' was the same as a neuron netowork layer, ex. 
        conv layer; one of these layers may contain more than one residual block)

        Args:
            block: block type, basic block or bottle neck block
            out_channels: output depth channel number of this layer
            num_blocks: how many blocks per layer
            stride: the stride of the first block of this layer

        Return:
            return a resnet layer
        """

        # we have num_blocks blocks per layer; the first block's stride
        # can be 1 or 2, the remaining blocks always use stride 1
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_channels, out_channels, stride))
            self.in_channels = out_channels * block.expansion

        return nn.Sequential(*layers)

    def forward(self, x):
        output = self.conv1(x)
        output = self.conv2_x(output)
        output = self.conv3_x(output)
        output = self.conv4_x(output)
        output = self.conv5_x(output)
        output = self.avgpool(output)
        output = output.view(output.size(0), -1)
        output = self.fc(output)

        return output


class SymbolicLoss:
    def __init__(self, category_depths=[3, 5, 5, 3], convergence_weighting=[1, .6, .3, .1],
                 divergence_weighting=[.1, .3, .6, 1]):
        self.depths = category_depths
        self.total_classes = 1
        for c in category_depths:
            self.total_classes *= c
        # elements_per_level[i] is how many leaf logits are pooled into a single
        # category score at level i (coarsest level first); for the default depths
        # [3, 5, 5, 3] this works out to [45, 9, 3, 1].
        self.elements_per_level = []
        m = 1
        for c in category_depths[1:]:
            m *= c
            self.elements_per_level.append(self.total_classes // m)
        self.elements_per_level = self.elements_per_level + [1]
        self.convergence_weighting = convergence_weighting
        self.divergence_weighting = divergence_weighting
        # TODO: improve the above logic, I'm sure it can be done better.

    def __call__(self, logits, collaboratorLabels):
        """
        Computes the symbolic loss.
        :param logits: Nested level scores for the network under training.
        :param collaboratorLabels: level labels from the collaborator network.
        :return: Convergence loss & divergence loss.
        """
        b, l = logits.shape
        assert l == self.total_classes, f"Expected {self.total_classes} predictions, got {l}"
        convergence_loss = 0
        divergence_loss = 0
        for epc, cw, dw in zip(self.elements_per_level, self.convergence_weighting, self.divergence_weighting):
            # collapse the leaf logits into one score per category at this level
            level_logits = logits.view(b, l // epc, epc)
            level_logits = level_logits.sum(dim=-1)
            level_labels = collaboratorLabels.div(epc, rounding_mode='trunc')

            # Convergence
            convergence_loss = convergence_loss + F.cross_entropy(level_logits, level_labels) * cw

            # Divergence
            div_label_indices = level_logits.argmax(dim=-1)
            # TODO: find the torch-y way of doing this.
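            # A possible vectorized replacement for the loop below (an untested assumption,
            # left as a comment so behaviour is unchanged): gather each sample's picked
            # column for the whole batch and transpose, i.e.
            #   div_preds = level_logits[:, div_label_indices].t()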
            dp = []
            for bi, i in enumerate(div_label_indices):
                dp.append(level_logits[:, i])
            div_preds = torch.stack(dp, dim=0)
            div_labels = torch.arange(0, b, device=logits.device)
            divergence_loss = divergence_loss + F.cross_entropy(div_preds, div_labels) * dw

        return convergence_loss, divergence_loss


if __name__ == '__main__':
    sl = SymbolicLoss()
    logits = torch.randn(5, sl.total_classes)
    labels = torch.randint(0, sl.total_classes, (5,))
    sl(logits, labels)


class TwinnedCifar(nn.Module):
    def __init__(self):
        super().__init__()
        self.loss = SymbolicLoss()
        self.netA = ResNet(BasicBlock, [2, 2, 2, 2], num_classes=self.loss.total_classes)
        self.netB = ResNet(BasicBlock, [2, 2, 2, 2], num_classes=self.loss.total_classes)

    def forward(self, x):
        y1 = self.netA(x)
        y2 = self.netB(x)
        b = x.shape[0]
        # each network is supervised on half the batch by the other network's hard labels
        convergenceA, divergenceA = self.loss(y1[:b // 2], y2.argmax(dim=-1)[:b // 2])
        convergenceB, divergenceB = self.loss(y2[b // 2:], y1.argmax(dim=-1)[b // 2:])
        return convergenceA + convergenceB, divergenceA + divergenceB


@register_model
def register_twin_cifar(opt_net, opt):
    """ return a TwinnedCifar object (two ResNet 18 networks trained jointly) """
    return TwinnedCifar()


def resnet34():
    """ return a ResNet 34 object """
    return ResNet(BasicBlock, [3, 4, 6, 3])


def resnet50():
    """ return a ResNet 50 object """
    return ResNet(BottleNeck, [3, 4, 6, 3])


def resnet101():
    """ return a ResNet 101 object """
    return ResNet(BottleNeck, [3, 4, 23, 3])


def resnet152():
    """ return a ResNet 152 object """
    return ResNet(BottleNeck, [3, 8, 36, 3])
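
# A minimal smoke-test sketch, not part of the original training pipeline. It assumes
# CIFAR-style 32x32 RGB inputs and an even batch size so each half-batch is non-empty,
# and only exercises the TwinnedCifar forward pass; how the convergence and divergence
# terms are weighted and combined is left to the trainer that consumes this model.
def _twin_cifar_smoke_test(batch_size=4):
    model = TwinnedCifar()
    images = torch.randn(batch_size, 3, 32, 32)
    convergence, divergence = model(images)
    print('convergence loss:', convergence.item(), 'divergence loss:', divergence.item())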