# Ported from the SPSR repo. This will need a lot of work to be properly
# integrated, but as of this commit it at least runs.
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision

from . import block as B
from . import spectral_norm as SN

class Get_gradient_nopadding(nn.Module):
    """Gradient-magnitude map via fixed central-difference kernels (no learned params)."""

    def __init__(self):
        super(Get_gradient_nopadding, self).__init__()
        kernel_v = [[0, -1, 0],
                    [0, 0, 0],
                    [0, 1, 0]]
        kernel_h = [[0, 0, 0],
                    [-1, 0, 1],
                    [0, 0, 0]]
        kernel_h = torch.FloatTensor(kernel_h).unsqueeze(0).unsqueeze(0)
        kernel_v = torch.FloatTensor(kernel_v).unsqueeze(0).unsqueeze(0)
        self.weight_h = nn.Parameter(data=kernel_h, requires_grad=False)
        self.weight_v = nn.Parameter(data=kernel_v, requires_grad=False)

    def forward(self, x):
        x_list = []
        for i in range(x.shape[1]):
            # Filter each channel separately with the vertical/horizontal kernels.
            x_i = x[:, i]
            x_i_v = F.conv2d(x_i.unsqueeze(1), self.weight_v, padding=1)
            x_i_h = F.conv2d(x_i.unsqueeze(1), self.weight_h, padding=1)
            # Gradient magnitude; the epsilon keeps sqrt differentiable at zero.
            x_i = torch.sqrt(torch.pow(x_i_v, 2) + torch.pow(x_i_h, 2) + 1e-6)
            x_list.append(x_i)

        x = torch.cat(x_list, dim=1)

        return x

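# A minimal usage sketch (shapes are illustrative assumptions): the module maps
# an (N, C, H, W) batch to an (N, C, H, W) gradient-magnitude map, e.g.
#
#   get_grad = Get_gradient_nopadding()
#   grad_map = get_grad(torch.rand(1, 3, 64, 64))  # -> torch.Size([1, 3, 64, 64])
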
####################
# Generator
####################

class SPSRNet(nn.Module):
    def __init__(self, in_nc, out_nc, nf, nb, gc=32, upscale=4, norm_type=None,
                 act_type='leakyrelu', mode='CNA', upsample_mode='upconv'):
        super(SPSRNet, self).__init__()

        n_upscale = int(math.log(upscale, 2))
        if upscale == 3:
            n_upscale = 1

        fea_conv = B.conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None)
        rb_blocks = [B.RRDB(nf, kernel_size=3, gc=32, stride=1, bias=True, pad_type='zero',
                            norm_type=norm_type, act_type=act_type, mode='CNA') for _ in range(nb)]

        LR_conv = B.conv_block(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None, mode=mode)

        if upsample_mode == 'upconv':
            upsample_block = B.upconv_blcok  # (sic: matches the helper's name in block.py)
        elif upsample_mode == 'pixelshuffle':
            upsample_block = B.pixelshuffle_block
        else:
            raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode))
        if upscale == 3:
            upsampler = upsample_block(nf, nf, 3, act_type=act_type)
        else:
            upsampler = [upsample_block(nf, nf, act_type=act_type) for _ in range(n_upscale)]

        self.HR_conv0_new = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type)
        self.HR_conv1_new = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=None)

        self.model = B.sequential(fea_conv, B.ShortcutBlock(B.sequential(*rb_blocks, LR_conv)),
                                  *upsampler, self.HR_conv0_new)

        self.get_g_nopadding = Get_gradient_nopadding()

        self.b_fea_conv = B.conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None)

        # Gradient branch: each b_block_i fuses 2*nf channels, each b_concat_i
        # projects them back down to nf.
        self.b_concat_1 = B.conv_block(2*nf, nf, kernel_size=3, norm_type=None, act_type=None)
        self.b_block_1 = B.RRDB(nf*2, kernel_size=3, gc=32, stride=1, bias=True, pad_type='zero',
                                norm_type=norm_type, act_type=act_type, mode='CNA')

        self.b_concat_2 = B.conv_block(2*nf, nf, kernel_size=3, norm_type=None, act_type=None)
        self.b_block_2 = B.RRDB(nf*2, kernel_size=3, gc=32, stride=1, bias=True, pad_type='zero',
                                norm_type=norm_type, act_type=act_type, mode='CNA')

        self.b_concat_3 = B.conv_block(2*nf, nf, kernel_size=3, norm_type=None, act_type=None)
        self.b_block_3 = B.RRDB(nf*2, kernel_size=3, gc=32, stride=1, bias=True, pad_type='zero',
                                norm_type=norm_type, act_type=act_type, mode='CNA')

        self.b_concat_4 = B.conv_block(2*nf, nf, kernel_size=3, norm_type=None, act_type=None)
        self.b_block_4 = B.RRDB(nf*2, kernel_size=3, gc=32, stride=1, bias=True, pad_type='zero',
                                norm_type=norm_type, act_type=act_type, mode='CNA')

        self.b_LR_conv = B.conv_block(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None, mode=mode)

        if upsample_mode == 'upconv':
            upsample_block = B.upconv_blcok
        elif upsample_mode == 'pixelshuffle':
            upsample_block = B.pixelshuffle_block
        else:
            raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode))
        if upscale == 3:
            b_upsampler = upsample_block(nf, nf, 3, act_type=act_type)
        else:
            b_upsampler = [upsample_block(nf, nf, act_type=act_type) for _ in range(n_upscale)]

        b_HR_conv0 = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type)
        b_HR_conv1 = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=None)

        self.b_module = B.sequential(*b_upsampler, b_HR_conv0, b_HR_conv1)

        self.conv_w = B.conv_block(nf, out_nc, kernel_size=1, norm_type=None, act_type=None)

        self.f_concat = B.conv_block(nf*2, nf, kernel_size=3, norm_type=None, act_type=None)

        self.f_block = B.RRDB(nf*2, kernel_size=3, gc=32, stride=1, bias=True, pad_type='zero',
                              norm_type=norm_type, act_type=act_type, mode='CNA')

        self.f_HR_conv0 = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type)
        self.f_HR_conv1 = B.conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None)

    def forward(self, x):
        x_grad = self.get_g_nopadding(x)
        x = self.model[0](x)

        # ShortcutBlock returns both the features and its list of sub-blocks so
        # intermediate features can be tapped for the gradient branch below.
        x, block_list = self.model[1](x)

        # Tap features every 5 RRDBs; this assumes nb >= 20 (SPSR uses nb=23).
        x_ori = x
        for i in range(5):
            x = block_list[i](x)
        x_fea1 = x

        for i in range(5):
            x = block_list[i+5](x)
        x_fea2 = x

        for i in range(5):
            x = block_list[i+10](x)
        x_fea3 = x

        for i in range(5):
            x = block_list[i+15](x)
        x_fea4 = x

        x = block_list[20:](x)
        # shortcut
        x = x_ori + x
        x = self.model[2:](x)
        x = self.HR_conv1_new(x)

        x_b_fea = self.b_fea_conv(x_grad)
        x_cat_1 = torch.cat([x_b_fea, x_fea1], dim=1)

        x_cat_1 = self.b_block_1(x_cat_1)
        x_cat_1 = self.b_concat_1(x_cat_1)

        x_cat_2 = torch.cat([x_cat_1, x_fea2], dim=1)

        x_cat_2 = self.b_block_2(x_cat_2)
        x_cat_2 = self.b_concat_2(x_cat_2)

        x_cat_3 = torch.cat([x_cat_2, x_fea3], dim=1)

        x_cat_3 = self.b_block_3(x_cat_3)
        x_cat_3 = self.b_concat_3(x_cat_3)

        x_cat_4 = torch.cat([x_cat_3, x_fea4], dim=1)

        x_cat_4 = self.b_block_4(x_cat_4)
        x_cat_4 = self.b_concat_4(x_cat_4)

        x_cat_4 = self.b_LR_conv(x_cat_4)

        # shortcut
        x_cat_4 = x_cat_4 + x_b_fea
        x_branch = self.b_module(x_cat_4)

        x_out_branch = self.conv_w(x_branch)

        # Fuse the gradient branch with the SR branch.
        x_branch_d = x_branch
        x_f_cat = torch.cat([x_branch_d, x], dim=1)
        x_f_cat = self.f_block(x_f_cat)
        x_out = self.f_concat(x_f_cat)
        x_out = self.f_HR_conv0(x_out)
        x_out = self.f_HR_conv1(x_out)

        return x_out_branch, x_out, x_grad

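# A note on the three outputs (per the SPSR design, as I read it):
#   x_out_branch - HR gradient map predicted by the gradient branch
#   x_out        - the final SR image from the fused features
#   x_grad       - gradient map of the LR input (fed to the gradient branch)
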
####################
# Discriminator
####################


# VGG-style discriminator for 128x128 inputs
class Discriminator_VGG_128(nn.Module):
    def __init__(self, in_nc, base_nf, norm_type='batch', act_type='leakyrelu', mode='CNA'):
        super(Discriminator_VGG_128, self).__init__()
        # features
        # hxw, c
        # 128, 64
        conv0 = B.conv_block(in_nc, base_nf, kernel_size=3, norm_type=None, act_type=act_type,
                             mode=mode)
        conv1 = B.conv_block(base_nf, base_nf, kernel_size=4, stride=2, norm_type=norm_type,
                             act_type=act_type, mode=mode)
        # 64, 64
        conv2 = B.conv_block(base_nf, base_nf*2, kernel_size=3, stride=1, norm_type=norm_type,
                             act_type=act_type, mode=mode)
        conv3 = B.conv_block(base_nf*2, base_nf*2, kernel_size=4, stride=2, norm_type=norm_type,
                             act_type=act_type, mode=mode)
        # 32, 128
        conv4 = B.conv_block(base_nf*2, base_nf*4, kernel_size=3, stride=1, norm_type=norm_type,
                             act_type=act_type, mode=mode)
        conv5 = B.conv_block(base_nf*4, base_nf*4, kernel_size=4, stride=2, norm_type=norm_type,
                             act_type=act_type, mode=mode)
        # 16, 256
        conv6 = B.conv_block(base_nf*4, base_nf*8, kernel_size=3, stride=1, norm_type=norm_type,
                             act_type=act_type, mode=mode)
        conv7 = B.conv_block(base_nf*8, base_nf*8, kernel_size=4, stride=2, norm_type=norm_type,
                             act_type=act_type, mode=mode)
        # 8, 512
        conv8 = B.conv_block(base_nf*8, base_nf*8, kernel_size=3, stride=1, norm_type=norm_type,
                             act_type=act_type, mode=mode)
        conv9 = B.conv_block(base_nf*8, base_nf*8, kernel_size=4, stride=2, norm_type=norm_type,
                             act_type=act_type, mode=mode)
        # 4, 512
        self.features = B.sequential(conv0, conv1, conv2, conv3, conv4, conv5, conv6, conv7,
                                     conv8, conv9)

        # classifier (the 512 here assumes base_nf=64, i.e. base_nf*8 channels at 4x4)
        self.classifier = nn.Sequential(
            nn.Linear(512 * 4 * 4, 100), nn.LeakyReLU(0.2, True), nn.Linear(100, 1))

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

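# Sketch of the expected shapes (assuming base_nf=64):
#   d = Discriminator_VGG_128(in_nc=3, base_nf=64)
#   logit = d(torch.rand(1, 3, 128, 128))  # -> torch.Size([1, 1])
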
# VGG-style discriminator for 96x96 inputs
# NOTE: a second Discriminator_VGG_96 is defined further down; that later
# definition shadows this one at import time.
class Discriminator_VGG_96(nn.Module):
    def __init__(self, in_nc, base_nf, norm_type='batch', act_type='leakyrelu', mode='CNA'):
        super(Discriminator_VGG_96, self).__init__()
        # features
        # hxw, c
        # 96, 3
        conv0 = B.conv_block(in_nc, base_nf, kernel_size=3, norm_type=None, act_type=act_type,
                             mode=mode)
        conv1 = B.conv_block(base_nf, base_nf, kernel_size=4, stride=2, norm_type=norm_type,
                             act_type=act_type, mode=mode)
        # 48, 64
        conv2 = B.conv_block(base_nf, base_nf*2, kernel_size=3, stride=1, norm_type=norm_type,
                             act_type=act_type, mode=mode)
        conv3 = B.conv_block(base_nf*2, base_nf*2, kernel_size=4, stride=2, norm_type=norm_type,
                             act_type=act_type, mode=mode)
        # 24, 128
        conv4 = B.conv_block(base_nf*2, base_nf*4, kernel_size=3, stride=1, norm_type=norm_type,
                             act_type=act_type, mode=mode)
        conv5 = B.conv_block(base_nf*4, base_nf*4, kernel_size=4, stride=2, norm_type=norm_type,
                             act_type=act_type, mode=mode)
        # 12, 256
        conv6 = B.conv_block(base_nf*4, base_nf*8, kernel_size=3, stride=1, norm_type=norm_type,
                             act_type=act_type, mode=mode)
        conv7 = B.conv_block(base_nf*8, base_nf*8, kernel_size=4, stride=2, norm_type=norm_type,
                             act_type=act_type, mode=mode)
        # 6, 512
        self.features = B.sequential(conv0, conv1, conv2, conv3, conv4, conv5, conv6, conv7)

        # classifier
        self.classifier = nn.Sequential(
            nn.Linear(512 * 6 * 6, 100), nn.LeakyReLU(0.2, True), nn.Linear(100, 1))

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

# VGG-style discriminator for 64x64 inputs
class Discriminator_VGG_64(nn.Module):
    def __init__(self, in_nc, base_nf, norm_type='batch', act_type='leakyrelu', mode='CNA'):
        super(Discriminator_VGG_64, self).__init__()
        # features
        # hxw, c
        # 64, 3
        conv0 = B.conv_block(in_nc, base_nf, kernel_size=3, norm_type=None, act_type=act_type,
                             mode=mode)
        conv1 = B.conv_block(base_nf, base_nf, kernel_size=4, stride=2, norm_type=norm_type,
                             act_type=act_type, mode=mode)
        # 32, 64
        conv2 = B.conv_block(base_nf, base_nf*2, kernel_size=3, stride=1, norm_type=norm_type,
                             act_type=act_type, mode=mode)
        conv3 = B.conv_block(base_nf*2, base_nf*2, kernel_size=4, stride=2, norm_type=norm_type,
                             act_type=act_type, mode=mode)
        # 16, 128
        conv4 = B.conv_block(base_nf*2, base_nf*4, kernel_size=3, stride=1, norm_type=norm_type,
                             act_type=act_type, mode=mode)
        conv5 = B.conv_block(base_nf*4, base_nf*4, kernel_size=4, stride=2, norm_type=norm_type,
                             act_type=act_type, mode=mode)
        # 8, 256
        conv6 = B.conv_block(base_nf*4, base_nf*8, kernel_size=3, stride=1, norm_type=norm_type,
                             act_type=act_type, mode=mode)
        conv7 = B.conv_block(base_nf*8, base_nf*8, kernel_size=4, stride=2, norm_type=norm_type,
                             act_type=act_type, mode=mode)
        # 4, 512
        self.features = B.sequential(conv0, conv1, conv2, conv3, conv4, conv5, conv6, conv7)

        # classifier
        self.classifier = nn.Sequential(
            nn.Linear(512 * 4 * 4, 100), nn.LeakyReLU(0.2, True), nn.Linear(100, 1))

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

# VGG-style discriminator for 32x32 inputs
class Discriminator_VGG_32(nn.Module):
    def __init__(self, in_nc, base_nf, norm_type='batch', act_type='leakyrelu', mode='CNA'):
        super(Discriminator_VGG_32, self).__init__()
        # features
        # hxw, c
        # 32, 3
        conv0 = B.conv_block(in_nc, base_nf, kernel_size=3, norm_type=None, act_type=act_type,
                             mode=mode)
        conv1 = B.conv_block(base_nf, base_nf, kernel_size=4, stride=2, norm_type=norm_type,
                             act_type=act_type, mode=mode)
        # 16, 64
        conv2 = B.conv_block(base_nf, base_nf*2, kernel_size=3, stride=1, norm_type=norm_type,
                             act_type=act_type, mode=mode)
        conv3 = B.conv_block(base_nf*2, base_nf*2, kernel_size=4, stride=2, norm_type=norm_type,
                             act_type=act_type, mode=mode)
        # 8, 128
        conv4 = B.conv_block(base_nf*2, base_nf*4, kernel_size=3, stride=1, norm_type=norm_type,
                             act_type=act_type, mode=mode)
        conv5 = B.conv_block(base_nf*4, base_nf*4, kernel_size=4, stride=2, norm_type=norm_type,
                             act_type=act_type, mode=mode)
        # 4, 256
        self.features = B.sequential(conv0, conv1, conv2, conv3, conv4, conv5)

        # classifier
        self.classifier = nn.Sequential(
            nn.Linear(256 * 4 * 4, 100), nn.LeakyReLU(0.2, True), nn.Linear(100, 1))

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

# VGG-style discriminator for 16x16 inputs
class Discriminator_VGG_16(nn.Module):
    def __init__(self, in_nc, base_nf, norm_type='batch', act_type='leakyrelu', mode='CNA'):
        super(Discriminator_VGG_16, self).__init__()
        # features
        # hxw, c
        # 16, 3
        conv0 = B.conv_block(in_nc, base_nf, kernel_size=3, norm_type=None, act_type=act_type,
                             mode=mode)
        conv1 = B.conv_block(base_nf, base_nf, kernel_size=4, stride=2, norm_type=norm_type,
                             act_type=act_type, mode=mode)
        # 8, 64
        conv2 = B.conv_block(base_nf, base_nf*2, kernel_size=3, stride=1, norm_type=norm_type,
                             act_type=act_type, mode=mode)
        conv3 = B.conv_block(base_nf*2, base_nf*2, kernel_size=4, stride=2, norm_type=norm_type,
                             act_type=act_type, mode=mode)
        # 4, 128
        self.features = B.sequential(conv0, conv1, conv2, conv3)

        # classifier
        self.classifier = nn.Sequential(
            nn.Linear(128 * 4 * 4, 100), nn.LeakyReLU(0.2, True), nn.Linear(100, 1))

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

# VGG-style discriminator for 128x128 inputs, with spectral normalization
class Discriminator_VGG_128_SN(nn.Module):
    def __init__(self):
        super(Discriminator_VGG_128_SN, self).__init__()
        # features
        # hxw, c
        # 128, 64
        self.lrelu = nn.LeakyReLU(0.2, True)

        self.conv0 = SN.spectral_norm(nn.Conv2d(3, 64, 3, 1, 1))
        self.conv1 = SN.spectral_norm(nn.Conv2d(64, 64, 4, 2, 1))
        # 64, 64
        self.conv2 = SN.spectral_norm(nn.Conv2d(64, 128, 3, 1, 1))
        self.conv3 = SN.spectral_norm(nn.Conv2d(128, 128, 4, 2, 1))
        # 32, 128
        self.conv4 = SN.spectral_norm(nn.Conv2d(128, 256, 3, 1, 1))
        self.conv5 = SN.spectral_norm(nn.Conv2d(256, 256, 4, 2, 1))
        # 16, 256
        self.conv6 = SN.spectral_norm(nn.Conv2d(256, 512, 3, 1, 1))
        self.conv7 = SN.spectral_norm(nn.Conv2d(512, 512, 4, 2, 1))
        # 8, 512
        self.conv8 = SN.spectral_norm(nn.Conv2d(512, 512, 3, 1, 1))
        self.conv9 = SN.spectral_norm(nn.Conv2d(512, 512, 4, 2, 1))
        # 4, 512

        # classifier
        self.linear0 = SN.spectral_norm(nn.Linear(512 * 4 * 4, 100))
        self.linear1 = SN.spectral_norm(nn.Linear(100, 1))

    def forward(self, x):
        x = self.lrelu(self.conv0(x))
        x = self.lrelu(self.conv1(x))
        x = self.lrelu(self.conv2(x))
        x = self.lrelu(self.conv3(x))
        x = self.lrelu(self.conv4(x))
        x = self.lrelu(self.conv5(x))
        x = self.lrelu(self.conv6(x))
        x = self.lrelu(self.conv7(x))
        x = self.lrelu(self.conv8(x))
        x = self.lrelu(self.conv9(x))
        x = x.view(x.size(0), -1)
        x = self.lrelu(self.linear0(x))
        x = self.linear1(x)
        return x

# VGG-style discriminator for 96x96 inputs (deeper variant)
# NOTE: this redefines the Discriminator_VGG_96 above; at import time this
# later definition is the one that wins.
class Discriminator_VGG_96(nn.Module):
    def __init__(self, in_nc, base_nf, norm_type='batch', act_type='leakyrelu', mode='CNA'):
        super(Discriminator_VGG_96, self).__init__()
        # features
        # hxw, c
        # 96, 64
        conv0 = B.conv_block(in_nc, base_nf, kernel_size=3, norm_type=None, act_type=act_type,
                             mode=mode)
        conv1 = B.conv_block(base_nf, base_nf, kernel_size=4, stride=2, norm_type=norm_type,
                             act_type=act_type, mode=mode)
        # 48, 64
        conv2 = B.conv_block(base_nf, base_nf*2, kernel_size=3, stride=1, norm_type=norm_type,
                             act_type=act_type, mode=mode)
        conv3 = B.conv_block(base_nf*2, base_nf*2, kernel_size=4, stride=2, norm_type=norm_type,
                             act_type=act_type, mode=mode)
        # 24, 128
        conv4 = B.conv_block(base_nf*2, base_nf*4, kernel_size=3, stride=1, norm_type=norm_type,
                             act_type=act_type, mode=mode)
        conv5 = B.conv_block(base_nf*4, base_nf*4, kernel_size=4, stride=2, norm_type=norm_type,
                             act_type=act_type, mode=mode)
        # 12, 256
        conv6 = B.conv_block(base_nf*4, base_nf*8, kernel_size=3, stride=1, norm_type=norm_type,
                             act_type=act_type, mode=mode)
        conv7 = B.conv_block(base_nf*8, base_nf*8, kernel_size=4, stride=2, norm_type=norm_type,
                             act_type=act_type, mode=mode)
        # 6, 512
        conv8 = B.conv_block(base_nf*8, base_nf*8, kernel_size=3, stride=1, norm_type=norm_type,
                             act_type=act_type, mode=mode)
        conv9 = B.conv_block(base_nf*8, base_nf*8, kernel_size=4, stride=2, norm_type=norm_type,
                             act_type=act_type, mode=mode)
        # 3, 512
        self.features = B.sequential(conv0, conv1, conv2, conv3, conv4, conv5, conv6, conv7,
                                     conv8, conv9)

        # classifier
        self.classifier = nn.Sequential(
            nn.Linear(512 * 3 * 3, 100), nn.LeakyReLU(0.2, True), nn.Linear(100, 1))

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

# VGG-style discriminator for 192x192 inputs
class Discriminator_VGG_192(nn.Module):
    def __init__(self, in_nc, base_nf, norm_type='batch', act_type='leakyrelu', mode='CNA'):
        super(Discriminator_VGG_192, self).__init__()
        # features
        # hxw, c
        # 192, 64
        conv0 = B.conv_block(in_nc, base_nf, kernel_size=3, norm_type=None, act_type=act_type,
                             mode=mode)
        conv1 = B.conv_block(base_nf, base_nf, kernel_size=4, stride=2, norm_type=norm_type,
                             act_type=act_type, mode=mode)
        # 96, 64
        conv2 = B.conv_block(base_nf, base_nf*2, kernel_size=3, stride=1, norm_type=norm_type,
                             act_type=act_type, mode=mode)
        conv3 = B.conv_block(base_nf*2, base_nf*2, kernel_size=4, stride=2, norm_type=norm_type,
                             act_type=act_type, mode=mode)
        # 48, 128
        conv4 = B.conv_block(base_nf*2, base_nf*4, kernel_size=3, stride=1, norm_type=norm_type,
                             act_type=act_type, mode=mode)
        conv5 = B.conv_block(base_nf*4, base_nf*4, kernel_size=4, stride=2, norm_type=norm_type,
                             act_type=act_type, mode=mode)
        # 24, 256
        conv6 = B.conv_block(base_nf*4, base_nf*8, kernel_size=3, stride=1, norm_type=norm_type,
                             act_type=act_type, mode=mode)
        conv7 = B.conv_block(base_nf*8, base_nf*8, kernel_size=4, stride=2, norm_type=norm_type,
                             act_type=act_type, mode=mode)
        # 12, 512
        conv8 = B.conv_block(base_nf*8, base_nf*8, kernel_size=3, stride=1, norm_type=norm_type,
                             act_type=act_type, mode=mode)
        conv9 = B.conv_block(base_nf*8, base_nf*8, kernel_size=4, stride=2, norm_type=norm_type,
                             act_type=act_type, mode=mode)
        # 6, 512
        conv10 = B.conv_block(base_nf*8, base_nf*8, kernel_size=3, stride=1, norm_type=norm_type,
                              act_type=act_type, mode=mode)
        conv11 = B.conv_block(base_nf*8, base_nf*8, kernel_size=4, stride=2, norm_type=norm_type,
                              act_type=act_type, mode=mode)
        # 3, 512
        self.features = B.sequential(conv0, conv1, conv2, conv3, conv4, conv5, conv6, conv7,
                                     conv8, conv9, conv10, conv11)

        # classifier
        self.classifier = nn.Sequential(
            nn.Linear(512 * 3 * 3, 100), nn.LeakyReLU(0.2, True), nn.Linear(100, 1))

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

####################
# Perceptual Network
####################

class VGGFeatureExtractor(nn.Module):
    def __init__(self,
                 feature_layer=34,
                 use_bn=False,
                 use_input_norm=True,
                 device=torch.device('cpu')):
        super(VGGFeatureExtractor, self).__init__()
        if use_bn:
            model = torchvision.models.vgg19_bn(pretrained=True)
        else:
            model = torchvision.models.vgg19(pretrained=True)
        self.use_input_norm = use_input_norm
        if self.use_input_norm:
            mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device)
            # [0.485 - 1, 0.456 - 1, 0.406 - 1] if input is in range [-1, 1]
            std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device)
            # [0.229 * 2, 0.224 * 2, 0.225 * 2] if input is in range [-1, 1]
            self.register_buffer('mean', mean)
            self.register_buffer('std', std)
        self.features = nn.Sequential(*list(model.features.children())[:(feature_layer + 1)])
        # No need to backprop into the fixed feature extractor
        for k, v in self.features.named_parameters():
            v.requires_grad = False

    def forward(self, x):
        if self.use_input_norm:
            x = (x - self.mean) / self.std
        output = self.features(x)
        return output

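# Usage sketch: the default feature_layer=34 selects conv5_4 of VGG19, the layer
# commonly used for perceptual loss; a 128x128 input then yields 8x8 features, e.g.
#   vgg = VGGFeatureExtractor(feature_layer=34)
#   feat = vgg(torch.rand(1, 3, 128, 128))  # -> torch.Size([1, 512, 8, 8])
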
class ResNet101FeatureExtractor(nn.Module):
    def __init__(self, use_input_norm=True, device=torch.device('cpu')):
        super(ResNet101FeatureExtractor, self).__init__()
        model = torchvision.models.resnet101(pretrained=True)
        self.use_input_norm = use_input_norm
        if self.use_input_norm:
            mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device)
            # [0.485 - 1, 0.456 - 1, 0.406 - 1] if input is in range [-1, 1]
            std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device)
            # [0.229 * 2, 0.224 * 2, 0.225 * 2] if input is in range [-1, 1]
            self.register_buffer('mean', mean)
            self.register_buffer('std', std)
        self.features = nn.Sequential(*list(model.children())[:8])
        # No need to backprop into the fixed feature extractor
        for k, v in self.features.named_parameters():
            v.requires_grad = False

    def forward(self, x):
        if self.use_input_norm:
            x = (x - self.mean) / self.std
        output = self.features(x)
        return output

class MINCNet(nn.Module):
    def __init__(self):
        super(MINCNet, self).__init__()
        self.ReLU = nn.ReLU(True)
        self.conv11 = nn.Conv2d(3, 64, 3, 1, 1)
        self.conv12 = nn.Conv2d(64, 64, 3, 1, 1)
        self.maxpool1 = nn.MaxPool2d(2, stride=2, padding=0, ceil_mode=True)
        self.conv21 = nn.Conv2d(64, 128, 3, 1, 1)
        self.conv22 = nn.Conv2d(128, 128, 3, 1, 1)
        self.maxpool2 = nn.MaxPool2d(2, stride=2, padding=0, ceil_mode=True)
        self.conv31 = nn.Conv2d(128, 256, 3, 1, 1)
        self.conv32 = nn.Conv2d(256, 256, 3, 1, 1)
        self.conv33 = nn.Conv2d(256, 256, 3, 1, 1)
        self.maxpool3 = nn.MaxPool2d(2, stride=2, padding=0, ceil_mode=True)
        self.conv41 = nn.Conv2d(256, 512, 3, 1, 1)
        self.conv42 = nn.Conv2d(512, 512, 3, 1, 1)
        self.conv43 = nn.Conv2d(512, 512, 3, 1, 1)
        self.maxpool4 = nn.MaxPool2d(2, stride=2, padding=0, ceil_mode=True)
        self.conv51 = nn.Conv2d(512, 512, 3, 1, 1)
        self.conv52 = nn.Conv2d(512, 512, 3, 1, 1)
        self.conv53 = nn.Conv2d(512, 512, 3, 1, 1)

    def forward(self, x):
        out = self.ReLU(self.conv11(x))
        out = self.ReLU(self.conv12(out))
        out = self.maxpool1(out)
        out = self.ReLU(self.conv21(out))
        out = self.ReLU(self.conv22(out))
        out = self.maxpool2(out)
        out = self.ReLU(self.conv31(out))
        out = self.ReLU(self.conv32(out))
        out = self.ReLU(self.conv33(out))
        out = self.maxpool3(out)
        out = self.ReLU(self.conv41(out))
        out = self.ReLU(self.conv42(out))
        out = self.ReLU(self.conv43(out))
        out = self.maxpool4(out)
        out = self.ReLU(self.conv51(out))
        out = self.ReLU(self.conv52(out))
        out = self.conv53(out)
        return out

class MINCFeatureExtractor(nn.Module):
    def __init__(self, feature_layer=34, use_bn=False, use_input_norm=True,
                 device=torch.device('cpu')):
        super(MINCFeatureExtractor, self).__init__()
        # feature_layer, use_bn, use_input_norm and device are accepted for API
        # parity with VGGFeatureExtractor but are currently unused here.

        self.features = MINCNet()
        self.features.load_state_dict(
            torch.load('../experiments/pretrained_models/VGG16minc_53.pth'), strict=True)
        self.features.eval()
        # No need to backprop into the fixed feature extractor
        for k, v in self.features.named_parameters():
            v.requires_grad = False

    def forward(self, x):
        output = self.features(x)
        return output
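

if __name__ == "__main__":
    # Minimal smoke test, a sketch only: run as a module (python -m <package>.<this_file>)
    # so the relative imports resolve. Shapes below are illustrative assumptions;
    # nb=23 matters because forward() taps features from the first 20 RRDBs.
    net = SPSRNet(in_nc=3, out_nc=3, nf=64, nb=23, upscale=4)
    sr_branch, sr, grad = net(torch.rand(1, 3, 32, 32))
    print(sr_branch.shape, sr.shape, grad.shape)
    # expected: torch.Size([1, 3, 128, 128]) twice, then torch.Size([1, 3, 32, 32])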