124 lines
4.7 KiB
Python
124 lines
4.7 KiB
Python
# Adapted from torchvision's resnet.py: a ResNet backbone with the classification head removed and gradient checkpointing applied to each residual stage.
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
from torchvision.models.resnet import BasicBlock, Bottleneck
|
|
import torchvision
|
|
|
|
|
|
# Public API of this module: the headless ``Backbone`` class and its
# depth-specific factories. Every name listed here must be defined in this
# module, otherwise ``from <module> import *`` raises AttributeError.
__all__ = ['Backbone', 'backbone18', 'backbone34', 'backbone50',
           'backbone101', 'backbone152']
|
|
|
|
from trainer.networks import register_model
|
|
from utils.util import checkpoint
|
|
|
|
# Download locations of torchvision's published checkpoints, keyed by
# architecture name. The backbone factories in this module look up the plain
# ResNet entries ('resnet18' through 'resnet152') when pretrained=True.
model_urls = dict(
    resnet18='https://download.pytorch.org/models/resnet18-5c106cde.pth',
    resnet34='https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    resnet50='https://download.pytorch.org/models/resnet50-19c8e357.pth',
    resnet101='https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    resnet152='https://download.pytorch.org/models/resnet152-b121ed2d.pth',
    resnext50_32x4d='https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
    resnext101_32x8d='https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
    wide_resnet50_2='https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
    wide_resnet101_2='https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
)
|
|
|
|
|
|
class Backbone(torchvision.models.resnet.ResNet):
    """Headless torchvision ResNet that returns the output of every residual stage.

    After the parent constructor runs, the classification layer (``fc``) and
    the global pooling layer (``avgpool``) are removed. Each of ``layer1`` ..
    ``layer4`` is executed through ``checkpoint`` from ``utils.util``.
    Presumably that helper performs gradient (activation) checkpointing, as
    the file header states; confirm against its implementation.

    All constructor arguments are forwarded unchanged to
    ``torchvision.models.resnet.ResNet``. ``num_classes`` therefore only sizes
    the ``fc`` layer, which is deleted immediately.
    """

    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None):
        super().__init__(
            block,
            layers,
            num_classes=num_classes,
            zero_init_residual=zero_init_residual,
            groups=groups,
            width_per_group=width_per_group,
            replace_stride_with_dilation=replace_stride_with_dilation,
            norm_layer=norm_layer,
        )
        # Only the convolutional trunk is kept; the head is never used by
        # _forward_impl, so drop it (fc first, then avgpool).
        for unused_head in ('fc', 'avgpool'):
            delattr(self, unused_head)

    def _forward_impl(self, x):
        # Stem: conv1 -> bn1 -> relu -> maxpool.
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))

        # Each stage consumes the previous stage's output. Every intermediate
        # result is kept so the caller receives all four stage outputs.
        stage_outputs = []
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = checkpoint(stage, x)
            stage_outputs.append(x)
        return tuple(stage_outputs)

    def forward(self, x):
        """Return the tuple of outputs of ``layer1``, ``layer2``, ``layer3``, ``layer4``."""
        return self._forward_impl(x)
|
|
|
|
|
|
def _backbone(arch, block, layers, pretrained, progress, **kwargs):
    """Construct a :class:`Backbone` and optionally load published weights.

    Args:
        arch (str): key into ``model_urls`` selecting the checkpoint to fetch.
        block: residual block class passed to ``Backbone`` (the factories in
            this module use ``BasicBlock`` or ``Bottleneck``).
        layers (list[int]): number of blocks in each of the four stages.
        pretrained (bool): if True, download the checkpoint for ``arch`` and
            load it into the model.
        progress (bool): if True, display a progress bar while downloading.
        **kwargs: forwarded to :class:`Backbone`.

    Returns:
        Backbone: the constructed (and possibly pretrained) model.

    Raises:
        KeyError: if ``pretrained`` is True and ``arch`` is not in ``model_urls``.
    """
    model = Backbone(block, layers, **kwargs)
    if pretrained:
        # Reached through ``torch.hub``; this module does not import
        # ``load_state_dict_from_url`` as a bare name.
        state_dict = torch.hub.load_state_dict_from_url(model_urls[arch],
                                                        progress=progress)
        # Backbone deletes ``fc``, so the classifier weights in the published
        # checkpoint have no destination and would make the strict load fail.
        # Remove them in place, which keeps the state dict's ``_metadata``, and
        # keep strict loading so every remaining key is still validated.
        for key in [k for k in state_dict if k.startswith('fc.')]:
            del state_dict[key]
        model.load_state_dict(state_dict)
    return model
|
|
|
|
|
|
def backbone18(pretrained=False, progress=True, **kwargs):
    r"""ResNet-18 feature backbone (no classification head) from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Built from ``BasicBlock`` with stage depths ``[2, 2, 2, 2]``.

    Args:
        pretrained (bool): If True, loads the ImageNet weights published for ``resnet18``
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: forwarded to ``Backbone``
    """
    stage_depths = [2, 2, 2, 2]
    return _backbone(arch='resnet18', block=BasicBlock, layers=stage_depths,
                     pretrained=pretrained, progress=progress, **kwargs)
|
|
|
|
|
|
def backbone34(pretrained=False, progress=True, **kwargs):
    r"""ResNet-34 feature backbone (no classification head) from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Built from ``BasicBlock`` with stage depths ``[3, 4, 6, 3]``.

    Args:
        pretrained (bool): If True, loads the ImageNet weights published for ``resnet34``
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: forwarded to ``Backbone``
    """
    stage_depths = [3, 4, 6, 3]
    return _backbone(arch='resnet34', block=BasicBlock, layers=stage_depths,
                     pretrained=pretrained, progress=progress, **kwargs)
|
|
|
|
|
|
def backbone50(pretrained=False, progress=True, **kwargs):
    r"""ResNet-50 feature backbone (no classification head) from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Built from ``Bottleneck`` with stage depths ``[3, 4, 6, 3]``.

    Args:
        pretrained (bool): If True, loads the ImageNet weights published for ``resnet50``
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: forwarded to ``Backbone``
    """
    stage_depths = [3, 4, 6, 3]
    return _backbone(arch='resnet50', block=Bottleneck, layers=stage_depths,
                     pretrained=pretrained, progress=progress, **kwargs)
|
|
|
|
|
|
def backbone101(pretrained=False, progress=True, **kwargs):
    r"""ResNet-101 feature backbone (no classification head) from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Built from ``Bottleneck`` with stage depths ``[3, 4, 23, 3]``.

    Args:
        pretrained (bool): If True, loads the ImageNet weights published for ``resnet101``
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: forwarded to ``Backbone``
    """
    stage_depths = [3, 4, 23, 3]
    return _backbone(arch='resnet101', block=Bottleneck, layers=stage_depths,
                     pretrained=pretrained, progress=progress, **kwargs)
|
|
|
|
|
|
def backbone152(pretrained=False, progress=True, **kwargs):
    r"""ResNet-152 feature backbone (no classification head) from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Built from ``Bottleneck`` with stage depths ``[3, 8, 36, 3]``.

    Args:
        pretrained (bool): If True, loads the ImageNet weights published for ``resnet152``
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: forwarded to ``Backbone``
    """
    stage_depths = [3, 8, 36, 3]
    return _backbone(arch='resnet152', block=Bottleneck, layers=stage_depths,
                     pretrained=pretrained, progress=progress, **kwargs)
|
|
|