# This is a port from the SPSR repo. It will need a lot of work to be properly integrated, but as of this commit it at least runs.
655 lines
26 KiB
655 lines
26 KiB
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from . import block as B
from . import spectral_norm as SN
class Get_gradient_nopadding(nn.Module):
    """Per-channel gradient-magnitude map built from fixed 3x3 difference kernels.

    Every channel of the input is filtered independently with a vertical and a
    horizontal central-difference kernel.  ``padding=1`` keeps the spatial size
    unchanged, and the two responses are combined as
    ``sqrt(v ** 2 + h ** 2 + 1e-6)``.
    """

    def __init__(self):
        super(Get_gradient_nopadding, self).__init__()
        kernel_v = [[0, -1, 0],
                    [0, 0, 0],
                    [0, 1, 0]]
        kernel_h = [[0, 0, 0],
                    [-1, 0, 1],
                    [0, 0, 0]]
        # Two unsqueezes give shape (1, 1, 3, 3): one output and one input channel.
        kernel_h = torch.FloatTensor(kernel_h).unsqueeze(0).unsqueeze(0)
        kernel_v = torch.FloatTensor(kernel_v).unsqueeze(0).unsqueeze(0)
        # Frozen parameters: they are part of the module state but never trained.
        self.weight_h = nn.Parameter(data=kernel_h, requires_grad=False)
        self.weight_v = nn.Parameter(data=kernel_v, requires_grad=False)

    def forward(self, x):
        """Return the gradient magnitude of ``x``, channel by channel.

        Args:
            x: 4-D tensor indexed as ``x[:, i]`` per channel, i.e. channels on
                dimension 1.

        Returns:
            Tensor with the same shape as ``x``.
        """
        x_list = []
        for i in range(x.shape[1]):
            # Re-insert a singleton channel dim so the single-channel kernels apply.
            x_i = x[:, i].unsqueeze(1)
            x_i_v = F.conv2d(x_i, self.weight_v, padding=1)
            x_i_h = F.conv2d(x_i, self.weight_h, padding=1)
            # The small constant keeps sqrt differentiable where the gradient is zero.
            x_i = torch.sqrt(torch.pow(x_i_v, 2) + torch.pow(x_i_h, 2) + 1e-6)
            # Without this append the concatenation below receives an empty list.
            x_list.append(x_i)
        return torch.cat(x_list, dim=1)
# Generator
class SPSRNet(nn.Module):
    """SPSR generator: an RRDB super-resolution trunk coupled with a gradient branch.

    ``forward`` returns three tensors: the output of the gradient branch, the
    fused final output, and the gradient map computed from the input.

    Args:
        in_nc: channels of the input tensor.
        out_nc: channels of both network outputs.
        nf: feature channels used throughout the network.
        nb: number of RRDB blocks in the trunk.  ``forward`` addresses the
            first 20 trunk blocks by index, so this must be at least 20.
        gc: accepted for interface compatibility but unused; every RRDB is
            built with ``gc=32``.
        upscale: upscaling factor.  ``int(log2(upscale))`` upsample blocks are
            stacked, except for ``3`` which uses a single block called with
            ``3`` as its third positional argument.
        norm_type: forwarded to the ``B`` building blocks.
        act_type: forwarded to the ``B`` building blocks.
        mode: forwarded to the two ``LR_conv`` blocks only; the RRDBs are
            always built with ``mode='CNA'``.
        upsample_mode: ``'upconv'`` or ``'pixelshuffle'``.

    Raises:
        NotImplementedError: if ``upsample_mode`` is neither of the two
            supported values.
    """

    def __init__(self, in_nc, out_nc, nf, nb, gc=32, upscale=4, norm_type=None,
                 act_type='leakyrelu', mode='CNA', upsample_mode='upconv'):
        super(SPSRNet, self).__init__()

        n_upscale = int(math.log(upscale, 2))
        if upscale == 3:
            n_upscale = 1

        # The same upsample factory serves both the trunk and the gradient branch.
        if upsample_mode == 'upconv':
            upsample_block = B.upconv_blcok
        elif upsample_mode == 'pixelshuffle':
            upsample_block = B.pixelshuffle_block
        else:
            raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode))

        # ---- image trunk ----
        fea_conv = B.conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None)
        rb_blocks = [B.RRDB(nf, kernel_size=3, gc=32, stride=1, bias=True, pad_type='zero',
                            norm_type=norm_type, act_type=act_type, mode='CNA')
                     for _ in range(nb)]
        LR_conv = B.conv_block(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None,
                               mode=mode)

        if upscale == 3:
            upsampler = upsample_block(nf, nf, 3, act_type=act_type)
        else:
            upsampler = [upsample_block(nf, nf, act_type=act_type) for _ in range(n_upscale)]

        self.HR_conv0_new = B.conv_block(nf, nf, kernel_size=3, norm_type=None,
                                         act_type=act_type)
        self.HR_conv1_new = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=None)

        self.model = B.sequential(fea_conv, B.ShortcutBlock(B.sequential(*rb_blocks, LR_conv)),
                                  *upsampler, self.HR_conv0_new)

        # ---- gradient branch ----
        self.get_g_nopadding = Get_gradient_nopadding()

        self.b_fea_conv = B.conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None)

        self.b_concat_1 = B.conv_block(2*nf, nf, kernel_size=3, norm_type=None, act_type=None)
        self.b_block_1 = B.RRDB(nf*2, kernel_size=3, gc=32, stride=1, bias=True, pad_type='zero',
                                norm_type=norm_type, act_type=act_type, mode='CNA')

        self.b_concat_2 = B.conv_block(2*nf, nf, kernel_size=3, norm_type=None, act_type=None)
        self.b_block_2 = B.RRDB(nf*2, kernel_size=3, gc=32, stride=1, bias=True, pad_type='zero',
                                norm_type=norm_type, act_type=act_type, mode='CNA')

        self.b_concat_3 = B.conv_block(2*nf, nf, kernel_size=3, norm_type=None, act_type=None)
        self.b_block_3 = B.RRDB(nf*2, kernel_size=3, gc=32, stride=1, bias=True, pad_type='zero',
                                norm_type=norm_type, act_type=act_type, mode='CNA')

        self.b_concat_4 = B.conv_block(2*nf, nf, kernel_size=3, norm_type=None, act_type=None)
        self.b_block_4 = B.RRDB(nf*2, kernel_size=3, gc=32, stride=1, bias=True, pad_type='zero',
                                norm_type=norm_type, act_type=act_type, mode='CNA')

        self.b_LR_conv = B.conv_block(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None,
                                      mode=mode)

        if upscale == 3:
            b_upsampler = upsample_block(nf, nf, 3, act_type=act_type)
        else:
            b_upsampler = [upsample_block(nf, nf, act_type=act_type) for _ in range(n_upscale)]

        b_HR_conv0 = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type)
        b_HR_conv1 = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=None)

        self.b_module = B.sequential(*b_upsampler, b_HR_conv0, b_HR_conv1)

        # 1x1 projection of the gradient-branch features to the output channels.
        self.conv_w = B.conv_block(nf, out_nc, kernel_size=1, norm_type=None, act_type=None)

        # ---- fusion of both branches ----
        self.f_concat = B.conv_block(nf*2, nf, kernel_size=3, norm_type=None, act_type=None)
        self.f_block = B.RRDB(nf*2, kernel_size=3, gc=32, stride=1, bias=True, pad_type='zero',
                              norm_type=norm_type, act_type=act_type, mode='CNA')

        self.f_HR_conv0 = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type)
        self.f_HR_conv1 = B.conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None)

    def forward(self, x):
        """Run both branches and fuse them.

        Returns:
            Tuple ``(x_out_branch, x_out, x_grad)``: the gradient-branch
            output, the fused output, and the gradient map of the input.
        """
        x_grad = self.get_g_nopadding(x)
        x = self.model[0](x)

        # NOTE(review): relies on B.ShortcutBlock returning (tensor, block container)
        # instead of applying the shortcut itself - confirm against block.py.
        x, block_list = self.model[1](x)

        x_ori = x
        # Tap the trunk features after every 5 blocks for the first 20 blocks.
        trunk_feats = []
        for stage in range(4):
            for i in range(5):
                x = block_list[stage * 5 + i](x)
            trunk_feats.append(x)
        x_fea1, x_fea2, x_fea3, x_fea4 = trunk_feats

        # Everything from index 20 on is applied as a slice of the container.
        x = block_list[20:](x)
        # short cut
        x = x_ori + x
        x = self.model[2:](x)
        x = self.HR_conv1_new(x)

        # Gradient branch: at each step concatenate a trunk feature on the channel
        # dimension, run an RRDB on it, then a conv built with nf output channels.
        x_b_fea = self.b_fea_conv(x_grad)
        x_cat_1 = torch.cat([x_b_fea, x_fea1], dim=1)
        x_cat_1 = self.b_block_1(x_cat_1)
        x_cat_1 = self.b_concat_1(x_cat_1)

        x_cat_2 = torch.cat([x_cat_1, x_fea2], dim=1)
        x_cat_2 = self.b_block_2(x_cat_2)
        x_cat_2 = self.b_concat_2(x_cat_2)

        x_cat_3 = torch.cat([x_cat_2, x_fea3], dim=1)
        x_cat_3 = self.b_block_3(x_cat_3)
        x_cat_3 = self.b_concat_3(x_cat_3)

        x_cat_4 = torch.cat([x_cat_3, x_fea4], dim=1)
        x_cat_4 = self.b_block_4(x_cat_4)
        x_cat_4 = self.b_concat_4(x_cat_4)

        x_cat_4 = self.b_LR_conv(x_cat_4)
        # short cut
        x_cat_4 = x_cat_4 + x_b_fea
        x_branch = self.b_module(x_cat_4)

        x_out_branch = self.conv_w(x_branch)

        # Fuse the gradient-branch features with the trunk output.
        x_branch_d = x_branch
        x_f_cat = torch.cat([x_branch_d, x], dim=1)
        x_f_cat = self.f_block(x_f_cat)
        x_out = self.f_concat(x_f_cat)
        x_out = self.f_HR_conv0(x_out)
        x_out = self.f_HR_conv1(x_out)

        return x_out_branch, x_out, x_grad
# Discriminator
# VGG style Discriminator with input size 128*128
class Discriminator_VGG_128(nn.Module):
    """VGG-style discriminator; the size annotations assume a 128x128 input.

    Ten ``B.conv_block`` layers alternate a 3x3 block with a 4x4 stride-2
    block, followed by a small fully connected classifier that emits one
    value per sample.

    Args:
        in_nc: channels of the input tensor.
        base_nf: channels produced by the first block; later blocks use
            multiples of it (up to ``base_nf * 8``).
        norm_type: normalisation passed to every block except the first.
        act_type: activation passed to every block.
        mode: forwarded to every block.
    """

    def __init__(self, in_nc, base_nf, norm_type='batch', act_type='leakyrelu', mode='CNA'):
        super(Discriminator_VGG_128, self).__init__()
        # Keyword arguments common to every block after the first one.
        shared = dict(norm_type=norm_type, act_type=act_type, mode=mode)
        # features
        # hxw, c
        # 128, 64
        conv0 = B.conv_block(in_nc, base_nf, kernel_size=3, norm_type=None, act_type=act_type,
                             mode=mode)
        conv1 = B.conv_block(base_nf, base_nf, kernel_size=4, stride=2, **shared)
        # 64, 64
        conv2 = B.conv_block(base_nf, base_nf*2, kernel_size=3, stride=1, **shared)
        conv3 = B.conv_block(base_nf*2, base_nf*2, kernel_size=4, stride=2, **shared)
        # 32, 128
        conv4 = B.conv_block(base_nf*2, base_nf*4, kernel_size=3, stride=1, **shared)
        conv5 = B.conv_block(base_nf*4, base_nf*4, kernel_size=4, stride=2, **shared)
        # 16, 256
        conv6 = B.conv_block(base_nf*4, base_nf*8, kernel_size=3, stride=1, **shared)
        conv7 = B.conv_block(base_nf*8, base_nf*8, kernel_size=4, stride=2, **shared)
        # 8, 512
        conv8 = B.conv_block(base_nf*8, base_nf*8, kernel_size=3, stride=1, **shared)
        conv9 = B.conv_block(base_nf*8, base_nf*8, kernel_size=4, stride=2, **shared)
        # 4, 512
        self.features = B.sequential(conv0, conv1, conv2, conv3, conv4, conv5, conv6, conv7,
                                     conv8, conv9)

        # classifier
        # The last block is built with base_nf*8 output channels on a 4x4 map (per
        # the annotations above); a literal 512 here only matched base_nf=64.
        self.classifier = nn.Sequential(
            nn.Linear(base_nf * 8 * 4 * 4, 100), nn.LeakyReLU(0.2, True), nn.Linear(100, 1))

    def forward(self, x):
        """Return one score per sample for the input batch ``x``."""
        x = self.features(x)
        # Flatten everything but the batch dimension for the linear layers.
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x
# VGG style Discriminator with input size 96*96
class Discriminator_VGG_96(nn.Module):
    """VGG-style discriminator; the size annotations assume a 96x96 input.

    Eight ``B.conv_block`` layers alternate a 3x3 block with a 4x4 stride-2
    block, followed by a small fully connected classifier that emits one
    value per sample.

    NOTE: a second class with this same name is defined further down in this
    file; at import time that later definition replaces this one.

    Args:
        in_nc: channels of the input tensor.
        base_nf: channels produced by the first block; later blocks use
            multiples of it (up to ``base_nf * 8``).
        norm_type: normalisation passed to every block except the first.
        act_type: activation passed to every block.
        mode: forwarded to every block.
    """

    def __init__(self, in_nc, base_nf, norm_type='batch', act_type='leakyrelu', mode='CNA'):
        super(Discriminator_VGG_96, self).__init__()
        # Keyword arguments common to every block after the first one.
        shared = dict(norm_type=norm_type, act_type=act_type, mode=mode)
        # features
        # hxw, c
        # 96, 3
        conv0 = B.conv_block(in_nc, base_nf, kernel_size=3, norm_type=None, act_type=act_type,
                             mode=mode)
        conv1 = B.conv_block(base_nf, base_nf, kernel_size=4, stride=2, **shared)
        # 48, 64
        conv2 = B.conv_block(base_nf, base_nf*2, kernel_size=3, stride=1, **shared)
        conv3 = B.conv_block(base_nf*2, base_nf*2, kernel_size=4, stride=2, **shared)
        # 24, 128
        conv4 = B.conv_block(base_nf*2, base_nf*4, kernel_size=3, stride=1, **shared)
        conv5 = B.conv_block(base_nf*4, base_nf*4, kernel_size=4, stride=2, **shared)
        # 12, 256
        conv6 = B.conv_block(base_nf*4, base_nf*8, kernel_size=3, stride=1, **shared)
        conv7 = B.conv_block(base_nf*8, base_nf*8, kernel_size=4, stride=2, **shared)
        # 6, 512
        self.features = B.sequential(conv0, conv1, conv2, conv3, conv4, conv5, conv6, conv7)

        # classifier
        # The last block is built with base_nf*8 output channels on a 6x6 map (per
        # the annotations above); a literal 512 here only matched base_nf=64.
        self.classifier = nn.Sequential(
            nn.Linear(base_nf * 8 * 6 * 6, 100), nn.LeakyReLU(0.2, True), nn.Linear(100, 1))

    def forward(self, x):
        """Return one score per sample for the input batch ``x``."""
        x = self.features(x)
        # Flatten everything but the batch dimension for the linear layers.
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x
# VGG style Discriminator with input size 64*64
class Discriminator_VGG_64(nn.Module):
    """VGG-style discriminator; the size annotations assume a 64x64 input.

    Eight ``B.conv_block`` layers alternate a 3x3 block with a 4x4 stride-2
    block, followed by a small fully connected classifier that emits one
    value per sample.

    Args:
        in_nc: channels of the input tensor.
        base_nf: channels produced by the first block; later blocks use
            multiples of it (up to ``base_nf * 8``).
        norm_type: normalisation passed to every block except the first.
        act_type: activation passed to every block.
        mode: forwarded to every block.
    """

    def __init__(self, in_nc, base_nf, norm_type='batch', act_type='leakyrelu', mode='CNA'):
        super(Discriminator_VGG_64, self).__init__()
        # Keyword arguments common to every block after the first one.
        shared = dict(norm_type=norm_type, act_type=act_type, mode=mode)
        # features
        # hxw, c
        # 64, 3
        conv0 = B.conv_block(in_nc, base_nf, kernel_size=3, norm_type=None, act_type=act_type,
                             mode=mode)
        conv1 = B.conv_block(base_nf, base_nf, kernel_size=4, stride=2, **shared)
        # 32, 64
        conv2 = B.conv_block(base_nf, base_nf*2, kernel_size=3, stride=1, **shared)
        conv3 = B.conv_block(base_nf*2, base_nf*2, kernel_size=4, stride=2, **shared)
        # 16, 128
        conv4 = B.conv_block(base_nf*2, base_nf*4, kernel_size=3, stride=1, **shared)
        conv5 = B.conv_block(base_nf*4, base_nf*4, kernel_size=4, stride=2, **shared)
        # 8, 256
        conv6 = B.conv_block(base_nf*4, base_nf*8, kernel_size=3, stride=1, **shared)
        conv7 = B.conv_block(base_nf*8, base_nf*8, kernel_size=4, stride=2, **shared)
        # 4, 512
        self.features = B.sequential(conv0, conv1, conv2, conv3, conv4, conv5, conv6, conv7)

        # classifier
        # The last block is built with base_nf*8 output channels on a 4x4 map (per
        # the annotations above); a literal 512 here only matched base_nf=64.
        self.classifier = nn.Sequential(
            nn.Linear(base_nf * 8 * 4 * 4, 100), nn.LeakyReLU(0.2, True), nn.Linear(100, 1))

    def forward(self, x):
        """Return one score per sample for the input batch ``x``."""
        x = self.features(x)
        # Flatten everything but the batch dimension for the linear layers.
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x
# VGG style Discriminator with input size 32*32
class Discriminator_VGG_32(nn.Module):
    """VGG-style discriminator; the size annotations assume a 32x32 input.

    Six ``B.conv_block`` layers alternate a 3x3 block with a 4x4 stride-2
    block, followed by a small fully connected classifier that emits one
    value per sample.

    Args:
        in_nc: channels of the input tensor.
        base_nf: channels produced by the first block; later blocks use
            multiples of it (up to ``base_nf * 4``).
        norm_type: normalisation passed to every block except the first.
        act_type: activation passed to every block.
        mode: forwarded to every block.
    """

    def __init__(self, in_nc, base_nf, norm_type='batch', act_type='leakyrelu', mode='CNA'):
        super(Discriminator_VGG_32, self).__init__()
        # Keyword arguments common to every block after the first one.
        shared = dict(norm_type=norm_type, act_type=act_type, mode=mode)
        # features
        # hxw, c
        # 32, 3
        conv0 = B.conv_block(in_nc, base_nf, kernel_size=3, norm_type=None, act_type=act_type,
                             mode=mode)
        conv1 = B.conv_block(base_nf, base_nf, kernel_size=4, stride=2, **shared)
        # 16, 64
        conv2 = B.conv_block(base_nf, base_nf*2, kernel_size=3, stride=1, **shared)
        conv3 = B.conv_block(base_nf*2, base_nf*2, kernel_size=4, stride=2, **shared)
        # 8, 128
        conv4 = B.conv_block(base_nf*2, base_nf*4, kernel_size=3, stride=1, **shared)
        conv5 = B.conv_block(base_nf*4, base_nf*4, kernel_size=4, stride=2, **shared)
        # 4, 256
        self.features = B.sequential(conv0, conv1, conv2, conv3, conv4, conv5)

        # classifier
        # The last block is built with base_nf*4 output channels on a 4x4 map (per
        # the annotations above); a literal 256 here only matched base_nf=64.
        self.classifier = nn.Sequential(
            nn.Linear(base_nf * 4 * 4 * 4, 100), nn.LeakyReLU(0.2, True), nn.Linear(100, 1))

    def forward(self, x):
        """Return one score per sample for the input batch ``x``."""
        x = self.features(x)
        # Flatten everything but the batch dimension for the linear layers.
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x
# VGG style Discriminator with input size 16*16
class Discriminator_VGG_16(nn.Module):
    """VGG-style discriminator; the size annotations assume a 16x16 input.

    Four ``B.conv_block`` layers alternate a 3x3 block with a 4x4 stride-2
    block, followed by a small fully connected classifier that emits one
    value per sample.

    Args:
        in_nc: channels of the input tensor.
        base_nf: channels produced by the first block; later blocks use
            multiples of it (up to ``base_nf * 2``).
        norm_type: normalisation passed to every block except the first.
        act_type: activation passed to every block.
        mode: forwarded to every block.
    """

    def __init__(self, in_nc, base_nf, norm_type='batch', act_type='leakyrelu', mode='CNA'):
        super(Discriminator_VGG_16, self).__init__()
        # Keyword arguments common to every block after the first one.
        shared = dict(norm_type=norm_type, act_type=act_type, mode=mode)
        # features
        # hxw, c
        # 16, 3
        conv0 = B.conv_block(in_nc, base_nf, kernel_size=3, norm_type=None, act_type=act_type,
                             mode=mode)
        conv1 = B.conv_block(base_nf, base_nf, kernel_size=4, stride=2, **shared)
        # 8, 64
        conv2 = B.conv_block(base_nf, base_nf*2, kernel_size=3, stride=1, **shared)
        conv3 = B.conv_block(base_nf*2, base_nf*2, kernel_size=4, stride=2, **shared)
        # 4, 128
        self.features = B.sequential(conv0, conv1, conv2, conv3)

        # classifier
        # The last block is built with base_nf*2 output channels on a 4x4 map (per
        # the annotations above); a literal 128 here only matched base_nf=64.
        self.classifier = nn.Sequential(
            nn.Linear(base_nf * 2 * 4 * 4, 100), nn.LeakyReLU(0.2, True), nn.Linear(100, 1))

    def forward(self, x):
        """Return one score per sample for the input batch ``x``."""
        x = self.features(x)
        # Flatten everything but the batch dimension for the linear layers.
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x
# VGG style Discriminator with input size 128*128, Spectral Normalization
class Discriminator_VGG_128_SN(nn.Module):
    """VGG-style discriminator whose layers are all wrapped in ``SN.spectral_norm``.

    Ten convolutions (``conv0`` .. ``conv9``) are each followed by a shared
    LeakyReLU, then two linear layers produce one value per sample.  The
    channel counts are fixed (3 input channels), and the size annotations
    assume a 128x128 input.
    """

    def __init__(self):
        super(Discriminator_VGG_128_SN, self).__init__()
        # features
        # hxw, c
        self.lrelu = nn.LeakyReLU(0.2, True)

        # (in_channels, out_channels, kernel_size, stride) for conv0 .. conv9;
        # every convolution uses padding 1.
        conv_specs = (
            (3, 64, 3, 1),       # 128, 64
            (64, 64, 4, 2),
            (64, 128, 3, 1),     # 64, 64
            (128, 128, 4, 2),
            (128, 256, 3, 1),    # 32, 128
            (256, 256, 4, 2),
            (256, 512, 3, 1),    # 16, 256
            (512, 512, 4, 2),
            (512, 512, 3, 1),    # 8, 512
            (512, 512, 4, 2),
        )
        # Registered under the attribute names conv0 .. conv9, in that order.
        for index, (c_in, c_out, k_size, stride) in enumerate(conv_specs):
            layer = SN.spectral_norm(nn.Conv2d(c_in, c_out, k_size, stride, 1))
            setattr(self, 'conv{}'.format(index), layer)
        # 4, 512

        # classifier
        self.linear0 = SN.spectral_norm(nn.Linear(512 * 4 * 4, 100))
        self.linear1 = SN.spectral_norm(nn.Linear(100, 1))

    def forward(self, x):
        """Return one score per sample for the input batch ``x``."""
        for index in range(10):
            conv = getattr(self, 'conv{}'.format(index))
            x = self.lrelu(conv(x))
        # Flatten everything but the batch dimension for the linear layers.
        flat = x.view(x.size(0), -1)
        hidden = self.lrelu(self.linear0(flat))
        return self.linear1(hidden)
class Discriminator_VGG_96(nn.Module):
    """Deeper VGG-style discriminator; the size annotations assume a 96x96 input.

    Ten ``B.conv_block`` layers alternate a 3x3 block with a 4x4 stride-2
    block, followed by a small fully connected classifier that emits one
    value per sample.

    NOTE: an earlier class with this same name (eight blocks, 6x6 classifier
    input) is defined above in this file; this definition replaces it at
    import time.

    Args:
        in_nc: channels of the input tensor.
        base_nf: channels produced by the first block; later blocks use
            multiples of it (up to ``base_nf * 8``).
        norm_type: normalisation passed to every block except the first.
        act_type: activation passed to every block.
        mode: forwarded to every block.
    """

    def __init__(self, in_nc, base_nf, norm_type='batch', act_type='leakyrelu', mode='CNA'):
        super(Discriminator_VGG_96, self).__init__()
        # Keyword arguments common to every block after the first one.
        shared = dict(norm_type=norm_type, act_type=act_type, mode=mode)
        # features
        # hxw, c
        # 96, 64
        conv0 = B.conv_block(in_nc, base_nf, kernel_size=3, norm_type=None, act_type=act_type,
                             mode=mode)
        conv1 = B.conv_block(base_nf, base_nf, kernel_size=4, stride=2, **shared)
        # 48, 64
        conv2 = B.conv_block(base_nf, base_nf*2, kernel_size=3, stride=1, **shared)
        conv3 = B.conv_block(base_nf*2, base_nf*2, kernel_size=4, stride=2, **shared)
        # 24, 128
        conv4 = B.conv_block(base_nf*2, base_nf*4, kernel_size=3, stride=1, **shared)
        conv5 = B.conv_block(base_nf*4, base_nf*4, kernel_size=4, stride=2, **shared)
        # 12, 256
        conv6 = B.conv_block(base_nf*4, base_nf*8, kernel_size=3, stride=1, **shared)
        conv7 = B.conv_block(base_nf*8, base_nf*8, kernel_size=4, stride=2, **shared)
        # 6, 512
        conv8 = B.conv_block(base_nf*8, base_nf*8, kernel_size=3, stride=1, **shared)
        conv9 = B.conv_block(base_nf*8, base_nf*8, kernel_size=4, stride=2, **shared)
        # 3, 512
        self.features = B.sequential(conv0, conv1, conv2, conv3, conv4, conv5, conv6, conv7,
                                     conv8, conv9)

        # classifier
        # The last block is built with base_nf*8 output channels on a 3x3 map (per
        # the annotations above); a literal 512 here only matched base_nf=64.
        self.classifier = nn.Sequential(
            nn.Linear(base_nf * 8 * 3 * 3, 100), nn.LeakyReLU(0.2, True), nn.Linear(100, 1))

    def forward(self, x):
        """Return one score per sample for the input batch ``x``."""
        x = self.features(x)
        # Flatten everything but the batch dimension for the linear layers.
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x
class Discriminator_VGG_192(nn.Module):
    """VGG-style discriminator; the size annotations assume a 192x192 input.

    Twelve ``B.conv_block`` layers alternate a 3x3 block with a 4x4 stride-2
    block, followed by a small fully connected classifier that emits one
    value per sample.

    Args:
        in_nc: channels of the input tensor.
        base_nf: channels produced by the first block; later blocks use
            multiples of it (up to ``base_nf * 8``).
        norm_type: normalisation passed to every block except the first.
        act_type: activation passed to every block.
        mode: forwarded to every block.
    """

    def __init__(self, in_nc, base_nf, norm_type='batch', act_type='leakyrelu', mode='CNA'):
        super(Discriminator_VGG_192, self).__init__()
        # Keyword arguments common to every block after the first one.
        shared = dict(norm_type=norm_type, act_type=act_type, mode=mode)
        # features
        # hxw, c
        # 192, 64
        conv0 = B.conv_block(in_nc, base_nf, kernel_size=3, norm_type=None, act_type=act_type,
                             mode=mode)
        conv1 = B.conv_block(base_nf, base_nf, kernel_size=4, stride=2, **shared)
        # 96, 64
        conv2 = B.conv_block(base_nf, base_nf*2, kernel_size=3, stride=1, **shared)
        conv3 = B.conv_block(base_nf*2, base_nf*2, kernel_size=4, stride=2, **shared)
        # 48, 128
        conv4 = B.conv_block(base_nf*2, base_nf*4, kernel_size=3, stride=1, **shared)
        conv5 = B.conv_block(base_nf*4, base_nf*4, kernel_size=4, stride=2, **shared)
        # 24, 256
        conv6 = B.conv_block(base_nf*4, base_nf*8, kernel_size=3, stride=1, **shared)
        conv7 = B.conv_block(base_nf*8, base_nf*8, kernel_size=4, stride=2, **shared)
        # 12, 512
        conv8 = B.conv_block(base_nf*8, base_nf*8, kernel_size=3, stride=1, **shared)
        conv9 = B.conv_block(base_nf*8, base_nf*8, kernel_size=4, stride=2, **shared)
        # 6, 512
        conv10 = B.conv_block(base_nf*8, base_nf*8, kernel_size=3, stride=1, **shared)
        conv11 = B.conv_block(base_nf*8, base_nf*8, kernel_size=4, stride=2, **shared)
        # 3, 512
        self.features = B.sequential(conv0, conv1, conv2, conv3, conv4, conv5, conv6, conv7,
                                     conv8, conv9, conv10, conv11)

        # classifier
        # The last block is built with base_nf*8 output channels on a 3x3 map (per
        # the annotations above); a literal 512 here only matched base_nf=64.
        self.classifier = nn.Sequential(
            nn.Linear(base_nf * 8 * 3 * 3, 100), nn.LeakyReLU(0.2, True), nn.Linear(100, 1))

    def forward(self, x):
        """Return one score per sample for the input batch ``x``."""
        x = self.features(x)
        # Flatten everything but the batch dimension for the linear layers.
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x
# Perceptual Network
class VGGFeatureExtractor(nn.Module):
    """Frozen, truncated torchvision VGG-19 ``features`` stack.

    Args:
        feature_layer: index of the last child of ``model.features`` to keep
            (children ``0 .. feature_layer`` inclusive are used).
        use_bn: pick ``vgg19_bn`` instead of ``vgg19``.
        use_input_norm: subtract ``mean`` and divide by ``std`` before the
            feature stack.
        device: device the ``mean`` / ``std`` tensors are moved to before they
            are registered as buffers.
    """

    # NOTE(review): the parameter list and its defaults were rebuilt from the
    # names used in the body and from MINCFeatureExtractor's signature in this
    # file - confirm the defaults against callers.
    def __init__(self,
                 feature_layer=34,
                 use_bn=False,
                 use_input_norm=True,
                 device=torch.device('cpu')):
        super(VGGFeatureExtractor, self).__init__()
        if use_bn:
            model = torchvision.models.vgg19_bn(pretrained=True)
        else:
            model = torchvision.models.vgg19(pretrained=True)
        self.use_input_norm = use_input_norm
        if self.use_input_norm:
            mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device)
            # [0.485-1, 0.456-1, 0.406-1] if input in range [-1,1]
            std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device)
            # [0.229*2, 0.224*2, 0.225*2] if input in range [-1,1]
            # Buffers (not parameters): saved with the module but never optimised.
            self.register_buffer('mean', mean)
            self.register_buffer('std', std)
        self.features = nn.Sequential(*list(model.features.children())[:(feature_layer + 1)])
        # No need to BP to variable
        for param in self.features.parameters():
            param.requires_grad = False

    def forward(self, x):
        """Return the truncated VGG features of ``x``, normalising first if enabled."""
        if self.use_input_norm:
            # mean/std are shaped (1, 3, 1, 1), so x is expected to have 3 channels.
            x = (x - self.mean) / self.std
        output = self.features(x)
        return output
class ResNet101FeatureExtractor(nn.Module):
    """Frozen feature extractor built from the first eight children of ResNet-101.

    Args:
        use_input_norm: subtract ``mean`` and divide by ``std`` before the
            feature stack.
        device: device the ``mean`` / ``std`` tensors are moved to before they
            are registered as buffers.
    """

    def __init__(self, use_input_norm=True, device=torch.device('cpu')):
        super(ResNet101FeatureExtractor, self).__init__()
        backbone = torchvision.models.resnet101(pretrained=True)
        self.use_input_norm = use_input_norm
        if use_input_norm:
            # [0.485-1, 0.456-1, 0.406-1] if input in range [-1,1]
            channel_mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device)
            # [0.229*2, 0.224*2, 0.225*2] if input in range [-1,1]
            channel_std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device)
            self.register_buffer('mean', channel_mean)
            self.register_buffer('std', channel_std)
        # Keep only the first eight top-level children of the backbone.
        # NOTE(review): presumably everything before the pooling/classifier head -
        # verify against the torchvision ResNet definition.
        kept_children = list(backbone.children())[:8]
        self.features = nn.Sequential(*kept_children)
        # No need to BP to variable
        for param in self.features.parameters():
            param.requires_grad = False

    def forward(self, x):
        """Return the backbone features of ``x``, normalising first if enabled."""
        if self.use_input_norm:
            x = (x - self.mean) / self.std
        return self.features(x)
class MINCNet(nn.Module):
    """VGG-16-shaped convolutional stack with five stages (2-2-3-3-3 convolutions).

    Every convolution is 3x3 with stride 1 and padding 1.  The first four
    stages end with a 2x2 ceil-mode max-pool.  All convolutions are followed
    by an in-place ReLU except the last one, whose raw output is returned.
    """

    def __init__(self):
        super(MINCNet, self).__init__()
        self.ReLU = nn.ReLU(True)

        # Stage 1
        self.conv11 = nn.Conv2d(3, 64, 3, 1, 1)
        self.conv12 = nn.Conv2d(64, 64, 3, 1, 1)
        self.maxpool1 = nn.MaxPool2d(2, stride=2, padding=0, ceil_mode=True)

        # Stage 2
        self.conv21 = nn.Conv2d(64, 128, 3, 1, 1)
        self.conv22 = nn.Conv2d(128, 128, 3, 1, 1)
        self.maxpool2 = nn.MaxPool2d(2, stride=2, padding=0, ceil_mode=True)

        # Stage 3
        self.conv31 = nn.Conv2d(128, 256, 3, 1, 1)
        self.conv32 = nn.Conv2d(256, 256, 3, 1, 1)
        self.conv33 = nn.Conv2d(256, 256, 3, 1, 1)
        self.maxpool3 = nn.MaxPool2d(2, stride=2, padding=0, ceil_mode=True)

        # Stage 4
        self.conv41 = nn.Conv2d(256, 512, 3, 1, 1)
        self.conv42 = nn.Conv2d(512, 512, 3, 1, 1)
        self.conv43 = nn.Conv2d(512, 512, 3, 1, 1)
        self.maxpool4 = nn.MaxPool2d(2, stride=2, padding=0, ceil_mode=True)

        # Stage 5 (no pooling afterwards)
        self.conv51 = nn.Conv2d(512, 512, 3, 1, 1)
        self.conv52 = nn.Conv2d(512, 512, 3, 1, 1)
        self.conv53 = nn.Conv2d(512, 512, 3, 1, 1)

    def forward(self, x):
        """Return the output of ``conv53`` (no activation applied) for input ``x``."""
        relu = self.ReLU
        pooled_stages = (
            ((self.conv11, self.conv12), self.maxpool1),
            ((self.conv21, self.conv22), self.maxpool2),
            ((self.conv31, self.conv32, self.conv33), self.maxpool3),
            ((self.conv41, self.conv42, self.conv43), self.maxpool4),
        )
        out = x
        for convs, pool in pooled_stages:
            for conv in convs:
                out = relu(conv(out))
            out = pool(out)
        for conv in (self.conv51, self.conv52):
            out = relu(conv(out))
        # The final convolution is deliberately left without a ReLU.
        return self.conv53(out)
class MINCFeatureExtractor(nn.Module):
    """Frozen ``MINCNet`` initialised from a checkpoint file on disk.

    The constructor reads ``../experiments/pretrained_models/VGG16minc_53.pth``
    relative to the current working directory and loads it with
    ``strict=True``.

    Args:
        feature_layer: accepted but not used by this class.
        use_bn: accepted but not used by this class.
        use_input_norm: accepted but not used by this class.
        device: accepted but not used by this class.
    """

    # NOTE(review): the last parameter was rebuilt to mirror
    # ResNet101FeatureExtractor's ``device`` argument in this file - confirm.
    def __init__(self, feature_layer=34, use_bn=False, use_input_norm=True,
                 device=torch.device('cpu')):
        super(MINCFeatureExtractor, self).__init__()

        self.features = MINCNet()
        # map_location='cpu' lets a checkpoint saved from GPU tensors load on a
        # machine without CUDA; load_state_dict copies the values into the
        # freshly built module's own parameters either way.
        self.features.load_state_dict(
            torch.load('../experiments/pretrained_models/VGG16minc_53.pth',
                       map_location='cpu'),
            strict=True)
        # No need to BP to variable
        for param in self.features.parameters():
            param.requires_grad = False

    def forward(self, x):
        """Return the ``MINCNet`` features of ``x``."""
        output = self.features(x)
        return output