forked from mrq/DL-Art-School
d1007ccfe7
- The pixpro latent now rescales the latent space instead of using a "coordinate vector", which **might** have performance implications.
- The latent against which the pixel loss is computed can now be a small, randomly sampled patch out of the entire latent, allowing further memory/computational discounts. Since the loss computation does not have a receptive field, this should not alter the loss.
- The instance projection size can now be separate from the pixel projection size.
- PixContrast removed entirely.
- ResUnet with full resolution added.
88 lines
3.5 KiB
Python
88 lines
3.5 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
from torchvision.models.resnet import BasicBlock, Bottleneck, conv1x1, conv3x3
|
|
from torchvision.models.utils import load_state_dict_from_url
|
|
import torchvision
|
|
|
|
from models.arch_util import ConvBnRelu
|
|
from models.pixel_level_contrastive_learning.resnet_unet import ReverseBottleneck
|
|
from trainer.networks import register_model
|
|
from utils.util import checkpoint, opt_get
|
|
|
|
|
|
class UResNet50_2(torchvision.models.resnet.ResNet):
    """torchvision ResNet encoder followed by a ReverseBottleneck decoder.

    The inherited stem and ``layer1``..``layer4`` act as the encoder. Five
    ``ReverseBottleneck`` stages then decode the ``layer4`` activation; every
    stage after the first also receives one encoder activation as a second
    input. The decoder result is concatenated (channel dim) with
    ``level_conv(x)``, which is computed directly from the network input, and
    fed through ``tail``, whose final 1x1 conv emits ``out_dim`` channels.

    NOTE(review): the hard-coded widths (2048 into the decoder, 128 into
    ``tail``) presumably assume ``block`` is ``Bottleneck`` - confirm before
    instantiating with another block type.
    """

    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None, out_dim=128):
        """Build the encoder via the parent constructor, then add the decoder.

        Args:
            block, layers, num_classes, zero_init_residual, groups,
            width_per_group, replace_stride_with_dilation: forwarded unchanged
                to the torchvision ``ResNet`` constructor.
            norm_layer: normalization layer factory; forwarded to the parent,
                to every ``ReverseBottleneck`` and used inside ``tail``.
                Falls back to ``nn.BatchNorm2d`` when ``None``.
            out_dim: number of channels produced by the last layer of ``tail``.
        """
        super().__init__(block, layers, num_classes, zero_init_residual, groups, width_per_group,
                         replace_stride_with_dilation, norm_layer)
        norm_layer = nn.BatchNorm2d if norm_layer is None else norm_layer

        # Side branch applied to the raw input; its output is concatenated with
        # the decoder output right before `tail`.
        self.level_conv = ConvBnRelu(3, 64)

        # For reference, the encoder stages built by the parent constructor:
        #   self.layer1 = self._make_layer(block, 64, layers[0])
        #   self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
        #                                  dilate=replace_stride_with_dilation[0])
        #   self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
        #                                  dilate=replace_stride_with_dilation[1])
        #   self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
        #                                  dilate=replace_stride_with_dilation[2])

        # Decoder widths: 2048 -> 1024 -> 512 -> 256 -> 64 -> 64.
        channels = 2048
        decoder = []
        for stage, divisor in enumerate((2, 2, 2, 4, 1)):
            reduced = channels // divisor
            # Only the first stage has no second (skip) input.
            decoder.append(ReverseBottleneck(channels, reduced, norm_layer=norm_layer,
                                             passthrough=stage > 0))
            channels = reduced
        self.uplayers = nn.ModuleList(decoder)

        # 128 input channels = 64 from the last decoder stage + 64 from level_conv.
        self.tail = nn.Sequential(
            conv3x3(128, 64),
            norm_layer(64),
            nn.ReLU(),
            conv1x1(64, out_dim),
        )

        # The classifier head is never called here; drop it so it does not
        # occupy GPU memory.
        del self.fc

    def _forward_impl(self, x):
        """Run encoder, decoder and tail; return the output of ``tail``."""
        level = self.level_conv(x)
        stem = self.relu(self.bn1(self.conv1(x)))
        pooled = self.maxpool(stem)

        enc1 = checkpoint(self.layer1, pooled)
        enc2 = checkpoint(self.layer2, enc1)
        enc3 = checkpoint(self.layer3, enc2)
        enc4 = checkpoint(self.layer4, enc3)
        # Result is discarded on purpose: avgpool is still executed for
        # instance-level pixpro learning.
        _pooled_unused = self.avgpool(enc4)

        # The first decoder stage takes only the deepest activation; the rest
        # pair with encoder activations from deepest to shallowest.
        out = checkpoint(self.uplayers[0], enc4)
        for idx, skip in enumerate((enc3, enc2, enc1, stem), start=1):
            out = checkpoint(self.uplayers[idx], out, skip)

        merged = torch.cat([out, level], dim=1)
        return checkpoint(self.tail, merged)

    def forward(self, x):
        """Delegate to ``_forward_impl``."""
        return self._forward_impl(x)
|
|
|
|
|
|
@register_model
def register_u_resnet50_2(opt_net, opt):
    """Construct a ``UResNet50_2`` from a network options mapping.

    Reads ``opt_net['odim']`` for the output channel count. When the optional
    ``use_pretrained_base`` entry is truthy, the torchvision ResNet-50
    checkpoint is downloaded and loaded non-strictly, so keys that do not
    match this model are ignored rather than raising.

    Args:
        opt_net: options for this network; must contain ``'odim'``.
        opt: not used by this function.

    Returns:
        The constructed model.
    """
    net = UResNet50_2(Bottleneck, [3, 4, 6, 3], out_dim=opt_net['odim'])
    if not opt_get(opt_net, ['use_pretrained_base'], False):
        return net

    pretrained = load_state_dict_from_url('https://download.pytorch.org/models/resnet50-19c8e357.pth', progress=True)
    # strict=False: the checkpoint and this model do not share all keys.
    net.load_state_dict(pretrained, strict=False)
    return net
|
|
|
|
|
|
if __name__ == '__main__':
    # Smoke test: push one random 224x224 RGB image through the network and
    # print the shape of the result.
    net = UResNet50_2(Bottleneck, [3, 4, 6, 3])
    dummy_input = torch.rand(1, 3, 224, 224)
    print(net(dummy_input).shape)
    # For pixpro: attach to "tail.3"
|