153 lines
5.5 KiB
Python
153 lines
5.5 KiB
Python
# ResNet implementation that adds a U-Net-style up-conversion component to output values at a
# specified pixel density.
#
# The downsampling part of the network is compatible with the built-in torchvision ResNet, for use
# in transfer learning.
#
# Only ResNet-50 is currently supported.
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
from torchvision.models.resnet import BasicBlock, Bottleneck, conv1x1, conv3x3
|
|
from torchvision.models.utils import load_state_dict_from_url
|
|
import torchvision
|
|
|
|
|
|
from trainer.networks import register_model
|
|
from utils.util import checkpoint, opt_get
|
|
|
|
|
|
class ReverseBottleneck(nn.Module):
    """Upsampling counterpart of torchvision's ``Bottleneck`` residual block.

    Both the residual path and the shortcut path apply nearest-neighbour
    upsampling with ``scale_factor=2``, so the output has twice the spatial
    size of the input. Channels are mapped from ``inplanes`` to ``planes``.

    Args:
        inplanes: Number of channels of the input tensor ``x``.
        planes: Number of channels of the output tensor.
        groups: Group count of the 3x3 convolution. It also multiplies the
            inner width, as in torchvision's ``Bottleneck``.
        passthrough: If True, ``forward`` requires a skip tensor. That tensor
            is concatenated with ``x`` along dim 1 and fused back down to
            ``inplanes`` channels by a 1x1 conv + norm before the block runs.
        base_width: Base width used to compute the inner width,
            ``int(planes * base_width / 64) * groups``.
        dilation: Dilation of the 3x3 convolution.
        norm_layer: Normalization layer factory; defaults to ``nn.BatchNorm2d``.
    """

    def __init__(self, inplanes, planes, groups=1, passthrough=False,
                 base_width=64, dilation=1, norm_layer=None):
        super().__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.)) * groups
        self.passthrough = passthrough
        if passthrough:
            # Fuses the concatenated [x, skip] tensor (2 * inplanes channels)
            # back down to inplanes channels.
            self.integrate = conv1x1(inplanes * 2, inplanes)
            self.bn_integrate = norm_layer(inplanes)

        # Residual path. Unlike torchvision's Bottleneck this block never
        # strides; resolution is increased by `residual_upsample` instead.
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        # groups/dilation must be passed by keyword: the third positional
        # parameter of conv3x3 is `stride`. Passing them positionally would
        # turn `groups` into a stride and `dilation` into the group count.
        self.conv2 = conv3x3(width, width, groups=groups, dilation=dilation)
        self.bn2 = norm_layer(width)
        self.residual_upsample = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='nearest'),
            conv1x1(width, width),
            norm_layer(width),
        )
        self.conv3 = conv1x1(width, planes)
        self.bn3 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)

        # Shortcut path: upsample and project inplanes -> planes so it can be
        # added to the residual path's output.
        self.upsample = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='nearest'),
            conv1x1(inplanes, planes),
            norm_layer(planes),
        )

    def forward(self, x, passthrough=None):
        """Upsample ``x`` by 2x and map it from ``inplanes`` to ``planes`` channels.

        Args:
            x: Input tensor with ``inplanes`` channels. Assumed to be batched
                (N, C, H, W), since the skip concat is along dim 1; TODO confirm.
            passthrough: Skip tensor concatenated with ``x`` along dim 1.
                Required when the block was built with ``passthrough=True``;
                ignored otherwise.

        Returns:
            Tensor with ``planes`` channels, after a final ReLU.

        Raises:
            TypeError: If the block was built with ``passthrough=True`` but no
                skip tensor was supplied.
        """
        if self.passthrough:
            if passthrough is None:
                raise TypeError(
                    "ReverseBottleneck was built with passthrough=True; "
                    "forward() requires a passthrough tensor")
            x = self.bn_integrate(self.integrate(torch.cat([x, passthrough], dim=1)))

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.residual_upsample(out)

        out = self.conv3(out)
        out = self.bn3(out)

        # The shortcut is taken from x after the optional skip fusion above.
        identity = self.upsample(x)

        out = out + identity
        out = self.relu(out)

        return out
|
|
|
|
|
|
class UResNet50(torchvision.models.resnet.ResNet):
|
|
|
|
def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
|
|
groups=1, width_per_group=64, replace_stride_with_dilation=None,
|
|
norm_layer=None, out_dim=128):
|
|
super().__init__(block, layers, num_classes, |