import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision

from . import block as B


class Get_gradient_nopadding(nn.Module):
    """Compute a per-channel gradient-magnitude map.

    Every channel of the input is convolved on its own with fixed
    central-difference kernels (one vertical, one horizontal). The two results
    are combined as ``sqrt(g_v ** 2 + g_h ** 2 + 1e-6)``.

    The convolutions use ``padding=1``, so the spatial size is preserved
    despite the class name. The ``1e-6`` keeps the square root's gradient
    finite where both directional gradients are zero.

    The kernels are registered as frozen ``nn.Parameter`` objects
    (``requires_grad=False``). They therefore follow the module across devices
    and are stored in its state dict.
    """

    def __init__(self):
        super(Get_gradient_nopadding, self).__init__()
        kernel_v = [[0, -1, 0],
                    [0, 0, 0],
                    [0, 1, 0]]
        kernel_h = [[0, 0, 0],
                    [-1, 0, 1],
                    [0, 0, 0]]
        # Reshape to (out_channels=1, in_channels=1, 3, 3) as F.conv2d expects.
        kernel_h = torch.FloatTensor(kernel_h).unsqueeze(0).unsqueeze(0)
        kernel_v = torch.FloatTensor(kernel_v).unsqueeze(0).unsqueeze(0)
        self.weight_h = nn.Parameter(data=kernel_h, requires_grad=False)
        self.weight_v = nn.Parameter(data=kernel_v, requires_grad=False)

    def forward(self, x):
        """Return the gradient magnitude of every channel of ``x``.

        Args:
            x: 4-D tensor ``(N, C, H, W)``. Each channel is processed separately
                with the single-channel kernels.

        Returns:
            Tensor of the same shape as ``x``.
        """
        x_list = []
        for i in range(x.shape[1]):
            x_i = x[:, i].unsqueeze(1)  # (N, 1, H, W) slice for channel i
            x_i_v = F.conv2d(x_i, self.weight_v, padding=1)
            x_i_h = F.conv2d(x_i, self.weight_h, padding=1)
            x_list.append(torch.sqrt(torch.pow(x_i_v, 2) + torch.pow(x_i_h, 2) + 1e-6))
        return torch.cat(x_list, dim=1)


####################
# Generator
####################


class SPSRNet(nn.Module):
    """SPSR generator: an RRDB super-resolution trunk plus a gradient branch.

    **SR trunk** (``self.model``):
    ``fea_conv -> ShortcutBlock(nb RRDBs + LR_conv) -> upsampler(s) -> HR_conv0_new``,
    followed by ``HR_conv1_new``.

    **Gradient branch:** runs on the gradient map of the LR input. After every
    five trunk RRDBs it concatenates the trunk's intermediate features with its
    own, passes them through an RRDB, and reduces back to ``nf`` channels. It is
    then upsampled by its own upsampler.

    **Fusion:** the upsampled branch features are concatenated with the
    upsampled trunk features and fused into the final image.

    Args:
        in_nc: number of input image channels.
        out_nc: number of output image channels.
        nf: number of feature channels.
        nb: number of RRDBs in the trunk. ``forward`` taps the trunk after
            blocks 5, 10, 15 and 20, so ``nb`` must be at least 20.
        gc: growth channels passed to every ``B.RRDB``.
        upscale: upscaling factor. ``3`` builds a single factor-3 upsampling
            block. Otherwise ``int(log2(upscale))`` default-factor blocks are
            stacked (presumably x2 each, so only powers of two are reproduced
            exactly — confirm against block.py).
        norm_type: normalization used inside the RRDBs and ``LR_conv`` blocks.
        act_type: activation type forwarded to ``B`` blocks.
        mode: conv-block ordering forwarded to the ``LR_conv`` blocks (the
            RRDBs always use ``'CNA'``).
        upsample_mode: ``'upconv'`` or ``'pixelshuffle'``.

    Raises:
        ValueError: if ``nb < 20``.
        NotImplementedError: if ``upsample_mode`` is not recognised.
    """

    def __init__(self, in_nc, out_nc, nf, nb, gc=32, upscale=4, norm_type=None,
                 act_type='leakyrelu', mode='CNA', upsample_mode='upconv'):
        super(SPSRNet, self).__init__()
        if nb < 20:
            # forward() steps through trunk blocks 0..19 explicitly and treats
            # everything from index 20 on as the tail (remaining RRDBs + LR_conv).
            raise ValueError('SPSRNet needs nb >= 20 RRDB blocks, got {}'.format(nb))

        # math.log2 is exact for powers of two, unlike math.log(x, 2).
        n_upscale = int(math.log2(upscale))
        if upscale == 3:
            n_upscale = 1
        upsample_block = self._select_upsample_block(upsample_mode)

        def rrdb(channels):
            # Identical RRDB configuration for trunk, branch and fusion blocks.
            return B.RRDB(channels, kernel_size=3, gc=gc, stride=1, bias=True, pad_type='zero',
                          norm_type=norm_type, act_type=act_type, mode='CNA')

        # ---- SR trunk ----
        # Module creation order below matches the original implementation so
        # that random initialisation and state-dict keys are unchanged.
        fea_conv = B.conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None)
        rb_blocks = [rrdb(nf) for _ in range(nb)]
        LR_conv = B.conv_block(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None,
                               mode=mode)
        if upscale == 3:
            upsampler = upsample_block(nf, nf, 3, act_type=act_type)
        else:
            upsampler = [upsample_block(nf, nf, act_type=act_type) for _ in range(n_upscale)]
        self.HR_conv0_new = B.conv_block(nf, nf, kernel_size=3, norm_type=None,
                                         act_type=act_type)
        self.HR_conv1_new = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=None)
        # NOTE: for upscale == 3, `upsampler` is a single module that gets
        # star-unpacked, so it must be iterable (e.g. an nn.Sequential). Kept
        # as-is to preserve the model's indexing and state-dict keys.
        self.model = B.sequential(fea_conv, B.ShortcutBlock(B.sequential(*rb_blocks, LR_conv)),
                                  *upsampler, self.HR_conv0_new)

        # ---- gradient branch ----
        self.get_g_nopadding = Get_gradient_nopadding()
        self.b_fea_conv = B.conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None)
        # Each stage runs an RRDB over [branch features, trunk tap] (2*nf
        # channels), then a conv back down to nf channels.
        self.b_concat_1 = B.conv_block(2 * nf, nf, kernel_size=3, norm_type=None, act_type=None)
        self.b_block_1 = rrdb(nf * 2)
        self.b_concat_2 = B.conv_block(2 * nf, nf, kernel_size=3, norm_type=None, act_type=None)
        self.b_block_2 = rrdb(nf * 2)
        self.b_concat_3 = B.conv_block(2 * nf, nf, kernel_size=3, norm_type=None, act_type=None)
        self.b_block_3 = rrdb(nf * 2)
        self.b_concat_4 = B.conv_block(2 * nf, nf, kernel_size=3, norm_type=None, act_type=None)
        self.b_block_4 = rrdb(nf * 2)
        self.b_LR_conv = B.conv_block(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None,
                                      mode=mode)
        if upscale == 3:
            b_upsampler = upsample_block(nf, nf, 3, act_type=act_type)
        else:
            b_upsampler = [upsample_block(nf, nf, act_type=act_type) for _ in range(n_upscale)]
        b_HR_conv0 = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type)
        b_HR_conv1 = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=None)
        self.b_module = B.sequential(*b_upsampler, b_HR_conv0, b_HR_conv1)
        # 1x1 projection of the upsampled branch features to out_nc channels.
        self.conv_w = B.conv_block(nf, out_nc, kernel_size=1, norm_type=None, act_type=None)

        # ---- fusion of branch and trunk ----
        self.f_concat = B.conv_block(nf * 2, nf, kernel_size=3, norm_type=None, act_type=None)
        self.f_block = rrdb(nf * 2)
        self.f_HR_conv0 = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type)
        self.f_HR_conv1 = B.conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None)

    @staticmethod
    def _select_upsample_block(upsample_mode):
        """Return the ``B`` block factory for ``upsample_mode``.

        Raises:
            NotImplementedError: if the mode is neither ``'upconv'`` nor
                ``'pixelshuffle'``.
        """
        if upsample_mode == 'upconv':
            return B.upconv_blcok  # spelling matches the attribute defined on B
        if upsample_mode == 'pixelshuffle':
            return B.pixelshuffle_block
        raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode))

    def forward(self, x):
        """Run the generator on an LR batch.

        Args:
            x: LR input batch with ``in_nc`` channels (4-D, as required by
                ``Get_gradient_nopadding``).

        Returns:
            Tuple ``(x_out_branch, x_out, x_grad)``:

            * ``x_out_branch`` — upsampled gradient-branch features projected
              to ``out_nc`` channels by ``conv_w``;
            * ``x_out`` — final output after fusing branch and trunk;
            * ``x_grad`` — gradient-magnitude map of the input ``x``.
        """
        x_grad = self.get_g_nopadding(x)

        # ---- SR trunk, tapping features after RRDBs 5, 10, 15 and 20 ----
        x = self.model[0](x)
        # NOTE(review): the ShortcutBlock at model[1] must return
        # (input, body) without running the body, so that the trunk can be
        # stepped through block by block here — confirm against block.py.
        x, block_list = self.model[1](x)
        x_ori = x
        taps = []
        for start in range(0, 20, 5):
            for i in range(start, start + 5):
                x = block_list[i](x)
            taps.append(x)
        x_fea1, x_fea2, x_fea3, x_fea4 = taps
        x = block_list[20:](x)  # remaining RRDBs followed by LR_conv
        x = x_ori + x  # skip connection around the whole trunk body
        x = self.model[2:](x)  # upsampler(s) + HR_conv0_new
        x = self.HR_conv1_new(x)

        # ---- gradient branch, fusing one trunk tap per stage ----
        x_b_fea = self.b_fea_conv(x_grad)
        x_b = x_b_fea
        stages = (
            (self.b_block_1, self.b_concat_1, x_fea1),
            (self.b_block_2, self.b_concat_2, x_fea2),
            (self.b_block_3, self.b_concat_3, x_fea3),
            (self.b_block_4, self.b_concat_4, x_fea4),
        )
        for rrdb_block, reduce_conv, tap in stages:
            x_b = reduce_conv(rrdb_block(torch.cat([x_b, tap], dim=1)))
        x_b = self.b_LR_conv(x_b)
        x_b = x_b + x_b_fea  # skip connection around the branch body
        x_branch = self.b_module(x_b)
        x_out_branch = self.conv_w(x_branch)

        # ---- fusion of branch and trunk ----
        x_f_cat = torch.cat([x_branch, x], dim=1)
        x_f_cat = self.f_block(x_f_cat)
        x_out = self.f_concat(x_f_cat)
        x_out = self.f_HR_conv0(x_out)
        x_out = self.f_HR_conv1(x_out)
        return x_out_branch, x_out, x_grad


####################
# Discriminator
####################


# VGG style Discriminator with input size 128*128
class Discriminator_VGG_128(nn.Module):
    """VGG-style discriminator sized for 128x128 inputs.

    Ten conv blocks, alternating k3/stride-1 and k4/stride-2, widen the
    features to ``base_nf * 8`` channels. The classifier expects a 4x4 spatial
    map at that point. A 128x128 input produces that after the five stride-2
    blocks, assuming each stride-2 ``B.conv_block`` halves H and W — confirm
    against block.py.

    Returns one unnormalized score per sample, with shape ``(N, 1)``.
    """

    def __init__(self, in_nc, base_nf, norm_type='batch', act_type='leakyrelu', mode='CNA'):
        super(Discriminator_VGG_128, self).__init__()
        # features
        # Comments give (spatial size, channels) for a 128x128 input, base_nf=64.
        # 128, 64
        conv0 = B.conv_block(in_nc, base_nf, kernel_size=3, norm_type=None, act_type=act_type,
                             mode=mode)
        conv1 = B.conv_block(base_nf, base_nf, kernel_size=4, stride=2, norm_type=norm_type,
                             act_type=act_type, mode=mode)
        # 64, 64
        conv2 = B.conv_block(base_nf, base_nf * 2, kernel_size=3, stride=1, norm_type=norm_type,
                             act_type=act_type, mode=mode)
        conv3 = B.conv_block(base_nf * 2, base_nf * 2, kernel_size=4, stride=2,
                             norm_type=norm_type, act_type=act_type, mode=mode)
        # 32, 128
        conv4 = B.conv_block(base_nf * 2, base_nf * 4, kernel_size=3, stride=1,
                             norm_type=norm_type, act_type=act_type, mode=mode)
        conv5 = B.conv_block(base_nf * 4, base_nf * 4, kernel_size=4, stride=2,
                             norm_type=norm_type, act_type=act_type, mode=mode)
        # 16, 256
        conv6 = B.conv_block(base_nf * 4, base_nf * 8, kernel_size=3, stride=1,
                             norm_type=norm_type, act_type=act_type, mode=mode)
        conv7 = B.conv_block(base_nf * 8, base_nf * 8, kernel_size=4, stride=2,
                             norm_type=norm_type, act_type=act_type, mode=mode)
        # 8, 512
        conv8 = B.conv_block(base_nf * 8, base_nf * 8, kernel_size=3, stride=1,
                             norm_type=norm_type, act_type=act_type, mode=mode)
        conv9 = B.conv_block(base_nf * 8, base_nf * 8, kernel_size=4, stride=2,
                             norm_type=norm_type, act_type=act_type, mode=mode)
        # 4, 512
        self.features = B.sequential(conv0, conv1, conv2, conv3, conv4, conv5, conv6, conv7,
                                     conv8, conv9)

        # classifier
        # Input width follows base_nf (it was hard-coded as 512, which only
        # matched base_nf=64).
        self.classifier = nn.Sequential(
            nn.Linear(base_nf * 8 * 4 * 4, 100), nn.LeakyReLU(0.2, True), nn.Linear(100, 1))

    def forward(self, x):
        """Return the discriminator score for each sample in ``x``."""
        x = self.features(x)
        x = x.view(x.size(0), -1)  # flatten (N, C, 4, 4) -> (N, C*16)
        x = self.classifier(x)
        return x


####################
# Perceptual Network
####################


class VGGFeatureExtractor(nn.Module):
    """Frozen pretrained torchvision VGG19 (optionally VGG19-BN) feature extractor.

    Keeps ``model.features`` layers ``0 .. feature_layer`` (inclusive) and sets
    ``requires_grad = False`` on all of their parameters.

    With ``use_input_norm``, inputs are normalized with the mean/std constants
    below, which correspond to inputs in [0, 1]. The inline comments give the
    adjusted constants for inputs in [-1, 1].
    """

    def __init__(self, feature_layer=34, use_bn=False, use_input_norm=True,
                 device=torch.device('cpu')):
        super(VGGFeatureExtractor, self).__init__()
        if use_bn:
            model = torchvision.models.vgg19_bn(pretrained=True)
        else:
            model = torchvision.models.vgg19(pretrained=True)
        self.use_input_norm = use_input_norm
        if self.use_input_norm:
            mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device)
            # [0.485-1, 0.456-1, 0.406-1] if input in range [-1,1]
            std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device)
            # [0.229*2, 0.224*2, 0.225*2] if input in range [-1,1]
            self.register_buffer('mean', mean)
            self.register_buffer('std', std)
        self.features = nn.Sequential(*list(model.features.children())[:(feature_layer + 1)])
        # No need to backpropagate into the extractor's weights.
        for k, v in self.features.named_parameters():
            v.requires_grad = False

    def forward(self, x):
        """Return the VGG features of ``x`` (normalized first if enabled)."""
        if self.use_input_norm:
            x = (x - self.mean) / self.std
        output = self.features(x)
        return output


class ResNet101FeatureExtractor(nn.Module):
    """Frozen pretrained torchvision ResNet-101 feature extractor.

    Keeps the first 8 top-level children of the torchvision model: the stem
    and ``layer1``–``layer4``, i.e. everything before the global average pool
    and the ``fc`` layer. All of their parameters are frozen.

    Input normalization works as in ``VGGFeatureExtractor``.
    """

    def __init__(self, use_input_norm=True, device=torch.device('cpu')):
        super(ResNet101FeatureExtractor, self).__init__()
        model = torchvision.models.resnet101(pretrained=True)
        self.use_input_norm = use_input_norm
        if self.use_input_norm:
            mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device)
            # [0.485-1, 0.456-1, 0.406-1] if input in range [-1,1]
            std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device)
            # [0.229*2, 0.224*2, 0.225*2] if input in range [-1,1]
            self.register_buffer('mean', mean)
            self.register_buffer('std', std)
        self.features = nn.Sequential(*list(model.children())[:8])
        # No need to backpropagate into the extractor's weights.
        for k, v in self.features.named_parameters():
            v.requires_grad = False

    def forward(self, x):
        """Return the ResNet-101 features of ``x`` (normalized first if enabled)."""
        if self.use_input_norm:
            x = (x - self.mean) / self.std
        output = self.features(x)
        return output


class MINCNet(nn.Module):
    """VGG16-style convolutional stack (conv1_1 .. conv5_3) with no classifier.

    Every conv except the last is followed by ReLU. Max-pooling uses
    ``ceil_mode=True`` after each of the first four stages. The output of
    ``conv53`` is returned before its activation.
    """

    def __init__(self):
        super(MINCNet, self).__init__()
        self.ReLU = nn.ReLU(True)
        self.conv11 = nn.Conv2d(3, 64, 3, 1, 1)
        self.conv12 = nn.Conv2d(64, 64, 3, 1, 1)
        self.maxpool1 = nn.MaxPool2d(2, stride=2, padding=0, ceil_mode=True)
        self.conv21 = nn.Conv2d(64, 128, 3, 1, 1)
        self.conv22 = nn.Conv2d(128, 128, 3, 1, 1)
        self.maxpool2 = nn.MaxPool2d(2, stride=2, padding=0, ceil_mode=True)
        self.conv31 = nn.Conv2d(128, 256, 3, 1, 1)
        self.conv32 = nn.Conv2d(256, 256, 3, 1, 1)
        self.conv33 = nn.Conv2d(256, 256, 3, 1, 1)
        self.maxpool3 = nn.MaxPool2d(2, stride=2, padding=0, ceil_mode=True)
        self.conv41 = nn.Conv2d(256, 512, 3, 1, 1)
        self.conv42 = nn.Conv2d(512, 512, 3, 1, 1)
        self.conv43 = nn.Conv2d(512, 512, 3, 1, 1)
        self.maxpool4 = nn.MaxPool2d(2, stride=2, padding=0, ceil_mode=True)
        self.conv51 = nn.Conv2d(512, 512, 3, 1, 1)
        self.conv52 = nn.Conv2d(512, 512, 3, 1, 1)
        self.conv53 = nn.Conv2d(512, 512, 3, 1, 1)

    def forward(self, x):
        """Return the pre-activation ``conv53`` features of a 3-channel input."""
        out = self.ReLU(self.conv11(x))
        out = self.ReLU(self.conv12(out))
        out = self.maxpool1(out)
        out = self.ReLU(self.conv21(out))
        out = self.ReLU(self.conv22(out))
        out = self.maxpool2(out)
        out = self.ReLU(self.conv31(out))
        out = self.ReLU(self.conv32(out))
        out = self.ReLU(self.conv33(out))
        out = self.maxpool3(out)
        out = self.ReLU(self.conv41(out))
        out = self.ReLU(self.conv42(out))
        out = self.ReLU(self.conv43(out))
        out = self.maxpool4(out)
        out = self.ReLU(self.conv51(out))
        out = self.ReLU(self.conv52(out))
        out = self.conv53(out)
        return out


class MINCFeatureExtractor(nn.Module):
    """Frozen ``MINCNet`` with weights loaded from ``VGG16minc_53.pth``.

    Judging by the file name, the weights are a VGG16 trained on MINC. The
    checkpoint path is relative to the current working directory. The network
    is put in eval mode and all of its parameters are frozen.

    ``feature_layer``, ``use_bn``, ``use_input_norm`` and ``device`` are
    accepted for signature compatibility with the other extractors but are not
    used.
    """

    def __init__(self, feature_layer=34, use_bn=False, use_input_norm=True,
                 device=torch.device('cpu')):
        super(MINCFeatureExtractor, self).__init__()
        self.features = MINCNet()
        # Load onto CPU so a checkpoint saved from GPU also loads on CPU-only
        # hosts; load_state_dict copies the values into the existing params.
        self.features.load_state_dict(
            torch.load('../experiments/pretrained_models/VGG16minc_53.pth',
                       map_location='cpu'),
            strict=True)
        self.features.eval()
        # No need to backpropagate into the extractor's weights.
        for k, v in self.features.named_parameters():
            v.requires_grad = False

    def forward(self, x):
        """Return the MINCNet features of ``x``."""
        output = self.features(x)
        return output