# Gradient-checkpointed ResNet feature backbones derived from torchvision's
# resnet.py. The backbones drop the classifier head and return the outputs of
# the four residual stages instead of logits.
import torch
import torch.nn as nn
import torchvision
from torchvision.models.resnet import BasicBlock, Bottleneck

try:
    from torchvision.models.utils import load_state_dict_from_url
except ImportError:
    # Newer torchvision releases no longer ship ``torchvision.models.utils``;
    # ``torch.hub`` provides the same function with the same signature.
    from torch.hub import load_state_dict_from_url

from trainer.networks import register_model
from utils.util import checkpoint

# Only names actually defined in this module: listing undefined names here
# makes ``from <module> import *`` raise AttributeError.
__all__ = ['Backbone', 'backbone18', 'backbone34', 'backbone50', 'backbone101',
           'backbone152', 'register_resnet50']


model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
    'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
    'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
    'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
    'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
}


class Backbone(torchvision.models.resnet.ResNet):
    """ResNet trunk that returns multi-scale features instead of logits.

    The stem (``conv1`` / ``bn1`` / ``relu`` / ``maxpool``) runs normally;
    ``layer1``..``layer4`` are each run through the project's ``checkpoint``
    helper (per the file header, to enable gradient checkpointing). The
    ``fc`` and ``avgpool`` modules created by torchvision's constructor are
    deleted, so this module has no classification head.

    Constructor arguments are forwarded unchanged to
    ``torchvision.models.resnet.ResNet``. ``num_classes`` only sizes the
    ``fc`` layer, which is removed afterwards.
    """

    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None):
        # Pass the optional arguments by keyword so this keeps working if the
        # parent signature's positional order ever changes.
        super().__init__(block, layers,
                         num_classes=num_classes,
                         zero_init_residual=zero_init_residual,
                         groups=groups,
                         width_per_group=width_per_group,
                         replace_stride_with_dilation=replace_stride_with_dilation,
                         norm_layer=norm_layer)
        # Feature extractor only: drop the pooling + classifier head.
        del self.fc
        del self.avgpool

    def _forward_impl(self, x):
        """Run the stem, then each residual stage under ``checkpoint``.

        Returns:
            tuple: ``(l1, l2, l3, l4)``, the outputs of ``layer1``..``layer4``.
        """
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        l1 = checkpoint(self.layer1, x)
        l2 = checkpoint(self.layer2, l1)
        l3 = checkpoint(self.layer3, l2)
        l4 = checkpoint(self.layer4, l3)
        return l1, l2, l3, l4

    def forward(self, x):
        """Return the four stage outputs ``(l1, l2, l3, l4)`` for input ``x``."""
        return self._forward_impl(x)


def _backbone(arch, block, layers, pretrained, progress, **kwargs):
    """Build a ``Backbone`` and optionally load ImageNet weights for ``arch``.

    Args:
        arch (str): key into ``model_urls`` used when ``pretrained`` is set.
        block: residual block class (``BasicBlock`` or ``Bottleneck``).
        layers (list[int]): number of blocks in each of the four stages.
        pretrained (bool): if True, download and load the ``arch`` checkpoint.
        progress (bool): show a download progress bar on stderr.
        **kwargs: forwarded to ``Backbone``.

    Raises:
        KeyError: if ``pretrained`` is set and ``arch`` is not in ``model_urls``.
    """
    model = Backbone(block, layers, **kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
        # The checkpoint includes the ImageNet classifier (``fc.*``), but the
        # backbone deleted ``fc``; a strict load would fail on those keys.
        state_dict = {k: v for k, v in state_dict.items() if not k.startswith('fc.')}
        model.load_state_dict(state_dict)
    return model


def backbone18(pretrained=False, progress=True, **kwargs):
    r"""ResNet-18 feature backbone from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Returns a ``Backbone`` whose forward yields the four stage outputs.

    Args:
        pretrained (bool): If True, loads ImageNet-pretrained trunk weights
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _backbone('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
                     **kwargs)


def backbone34(pretrained=False, progress=True, **kwargs):
    r"""ResNet-34 feature backbone from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Returns a ``Backbone`` whose forward yields the four stage outputs.

    Args:
        pretrained (bool): If True, loads ImageNet-pretrained trunk weights
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _backbone('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
                     **kwargs)


def backbone50(pretrained=False, progress=True, **kwargs):
    r"""ResNet-50 feature backbone from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Returns a ``Backbone`` whose forward yields the four stage outputs.

    Args:
        pretrained (bool): If True, loads ImageNet-pretrained trunk weights
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _backbone('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
                     **kwargs)


def backbone101(pretrained=False, progress=True, **kwargs):
    r"""ResNet-101 feature backbone from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Returns a ``Backbone`` whose forward yields the four stage outputs.

    Args:
        pretrained (bool): If True, loads ImageNet-pretrained trunk weights
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _backbone('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
                     **kwargs)


def backbone152(pretrained=False, progress=True, **kwargs):
    r"""ResNet-152 feature backbone from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Returns a ``Backbone`` whose forward yields the four stage outputs.

    Args:
        pretrained (bool): If True, loads ImageNet-pretrained trunk weights
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _backbone('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
                     **kwargs)


@register_model
def register_resnet50(opt_net, opt):
    """Build torchvision's ResNet-50 classifier, optionally with a new head.

    Args:
        opt_net: network options; reads ``'pretrained'`` (bool) and
            ``'custom_head_logits'`` (int or falsy). When the latter is truthy,
            ``fc`` is replaced with a freshly initialised
            ``nn.Linear(in_features, custom_head_logits)``.
        opt: unused here; presumably required by the ``register_model``
            calling convention — TODO confirm.

    Returns:
        torchvision.models.resnet.ResNet: the classification model.
    """
    # ``resnet50`` is not defined in this module; use torchvision's constructor
    # (this factory needs the ``fc`` head, which ``Backbone`` removes).
    model = torchvision.models.resnet.resnet50(pretrained=opt_net['pretrained'])
    if opt_net['custom_head_logits']:
        # in_features is 512 * Bottleneck.expansion == 2048 for ResNet-50.
        model.fc = nn.Linear(model.fc.in_features, opt_net['custom_head_logits'])
    return model