Add SPSR_module

This is a port from the SPSR repo. It will need a lot of work to be properly integrated,
but as of this commit it at least runs.
James Betker 2020-08-01 22:02:54 -06:00
parent f33ed578a2
commit f894ba8f98
10 changed files with 1841 additions and 1 deletion
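For orientation, here is a sketch of the option structure SPSRModel reads, reconstructed from the accessor calls in the code below. The key names are exactly the ones the code uses; the values are illustrative placeholders, since the options/train_imgset_spsr.yml that train.py now points at is not included in this diff.

opt = {
    'model': 'spsr',  # dispatched by create_model() in models/__init__.py below
    'is_train': True,
    'path': {
        'pretrain_model_G': None,   # set to a .pth path to warm-start G
        'pretrain_model_D': None,
        'models': '../experiments/spsr/models',  # sample images land in <models>/../temp
    },
    'network_G': {'which_model_G': 'spsr_net', 'in_nc': 3, 'out_nc': 3,
                  'nf': 64, 'nb': 23, 'gc': 32, 'scale': 4,
                  'norm_type': None, 'mode': 'CNA'},
    'network_D': {'which_model_D': 'discriminator_vgg_128', 'in_nc': 3, 'nf': 64,
                  'norm_type': 'batch', 'act_type': 'leakyrelu', 'mode': 'CNA'},
    'train': {
        'pixel_weight': 1e-2, 'pixel_criterion': 'l1',
        'feature_weight': 1.0, 'feature_criterion': 'l1',
        'gan_type': 'vanilla', 'gan_weight': 5e-3,
        'gradient_pixel_weight': 1e-2, 'gradient_gan_weight': 5e-3,
        'pixel_branch_weight': 0.5, 'pixel_branch_criterion': 'l1',
        'lr_G': 1e-4, 'beta1_G': 0.9, 'weight_decay_G': 0,
        'lr_D': 1e-4, 'beta1_D': 0.9, 'weight_decay_D': 0,
        'lr_scheme': 'MultiStepLR', 'lr_steps': [50000, 100000], 'lr_gamma': 0.5,
        'D_update_ratio': 1, 'D_init_iters': 0,
        'Branch_pretrain': 0, 'Branch_init_iters': 1,
    },
}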

462
codes/models/SPSR_model.py Normal file

@@ -0,0 +1,462 @@
import os
import logging
from collections import OrderedDict
import torch
import torch.nn as nn
from torch.optim import lr_scheduler
import models.SPSR_networks as networks
from .base_model import BaseModel
from models.SPSR_modules.loss import GANLoss, GradientPenaltyLoss
import torch.nn.functional as F
logger = logging.getLogger('base')
class Get_gradient(nn.Module):
def __init__(self):
super(Get_gradient, self).__init__()
kernel_v = [[0, -1, 0],
[0, 0, 0],
[0, 1, 0]]
kernel_h = [[0, 0, 0],
[-1, 0, 1],
[0, 0, 0]]
kernel_h = torch.FloatTensor(kernel_h).unsqueeze(0).unsqueeze(0)
kernel_v = torch.FloatTensor(kernel_v).unsqueeze(0).unsqueeze(0)
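        # Fixed (non-trainable) difference kernels; calling .cuda() here hard-codes
        # GPU placement, one of the integration rough edges this port still has.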
self.weight_h = nn.Parameter(data = kernel_h, requires_grad = False).cuda()
self.weight_v = nn.Parameter(data = kernel_v, requires_grad = False).cuda()
def forward(self, x):
x0 = x[:, 0]
x1 = x[:, 1]
x2 = x[:, 2]
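        # NOTE: with a 3x3 kernel, padding=2 yields (H+2)x(W+2) maps; real and fake
        # images both pass through this op, so the losses below still compare
        # equal-sized tensors.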
x0_v = F.conv2d(x0.unsqueeze(1), self.weight_v, padding=2)
x0_h = F.conv2d(x0.unsqueeze(1), self.weight_h, padding=2)
x1_v = F.conv2d(x1.unsqueeze(1), self.weight_v, padding=2)
x1_h = F.conv2d(x1.unsqueeze(1), self.weight_h, padding=2)
x2_v = F.conv2d(x2.unsqueeze(1), self.weight_v, padding=2)
x2_h = F.conv2d(x2.unsqueeze(1), self.weight_h, padding=2)
x0 = torch.sqrt(torch.pow(x0_v, 2) + torch.pow(x0_h, 2) + 1e-6)
x1 = torch.sqrt(torch.pow(x1_v, 2) + torch.pow(x1_h, 2) + 1e-6)
x2 = torch.sqrt(torch.pow(x2_v, 2) + torch.pow(x2_h, 2) + 1e-6)
x = torch.cat([x0, x1, x2], dim=1)
return x
class Get_gradient_nopadding(nn.Module):
def __init__(self):
super(Get_gradient_nopadding, self).__init__()
kernel_v = [[0, -1, 0],
[0, 0, 0],
[0, 1, 0]]
kernel_h = [[0, 0, 0],
[-1, 0, 1],
[0, 0, 0]]
kernel_h = torch.FloatTensor(kernel_h).unsqueeze(0).unsqueeze(0)
kernel_v = torch.FloatTensor(kernel_v).unsqueeze(0).unsqueeze(0)
self.weight_h = nn.Parameter(data = kernel_h, requires_grad = False).cuda()
self.weight_v = nn.Parameter(data = kernel_v, requires_grad = False).cuda()
def forward(self, x):
x0 = x[:, 0]
x1 = x[:, 1]
x2 = x[:, 2]
x0_v = F.conv2d(x0.unsqueeze(1), self.weight_v, padding = 1)
x0_h = F.conv2d(x0.unsqueeze(1), self.weight_h, padding = 1)
x1_v = F.conv2d(x1.unsqueeze(1), self.weight_v, padding = 1)
x1_h = F.conv2d(x1.unsqueeze(1), self.weight_h, padding = 1)
x2_v = F.conv2d(x2.unsqueeze(1), self.weight_v, padding = 1)
x2_h = F.conv2d(x2.unsqueeze(1), self.weight_h, padding = 1)
x0 = torch.sqrt(torch.pow(x0_v, 2) + torch.pow(x0_h, 2) + 1e-6)
x1 = torch.sqrt(torch.pow(x1_v, 2) + torch.pow(x1_h, 2) + 1e-6)
x2 = torch.sqrt(torch.pow(x2_v, 2) + torch.pow(x2_h, 2) + 1e-6)
x = torch.cat([x0, x1, x2], dim=1)
return x
class SPSRModel(BaseModel):
def __init__(self, opt):
super(SPSRModel, self).__init__(opt)
train_opt = opt['train']
# define networks and load pretrained models
self.netG = networks.define_G(opt).to(self.device) # G
if self.is_train:
self.netD = networks.define_D(opt).to(self.device) # D
self.netD_grad = networks.define_D_grad(opt).to(self.device) # D_grad
self.netG.train()
self.netD.train()
self.netD_grad.train()
self.load() # load G and D if needed
# define losses, optimizer and scheduler
if self.is_train:
# G pixel loss
if train_opt['pixel_weight'] > 0:
l_pix_type = train_opt['pixel_criterion']
if l_pix_type == 'l1':
self.cri_pix = nn.L1Loss().to(self.device)
elif l_pix_type == 'l2':
self.cri_pix = nn.MSELoss().to(self.device)
else:
raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_pix_type))
self.l_pix_w = train_opt['pixel_weight']
else:
logger.info('Remove pixel loss.')
self.cri_pix = None
# G feature loss
if train_opt['feature_weight'] > 0:
l_fea_type = train_opt['feature_criterion']
if l_fea_type == 'l1':
self.cri_fea = nn.L1Loss().to(self.device)
elif l_fea_type == 'l2':
self.cri_fea = nn.MSELoss().to(self.device)
else:
raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_fea_type))
self.l_fea_w = train_opt['feature_weight']
else:
logger.info('Remove feature loss.')
self.cri_fea = None
if self.cri_fea: # load VGG perceptual loss
self.netF = networks.define_F(opt, use_bn=False).to(self.device)
# GD gan loss
self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device)
self.l_gan_w = train_opt['gan_weight']
# D_update_ratio and D_init_iters are for WGAN
self.D_update_ratio = train_opt['D_update_ratio'] if train_opt['D_update_ratio'] else 1
self.D_init_iters = train_opt['D_init_iters'] if train_opt['D_init_iters'] else 0
# Branch_init_iters
self.Branch_pretrain = train_opt['Branch_pretrain'] if train_opt['Branch_pretrain'] else 0
self.Branch_init_iters = train_opt['Branch_init_iters'] if train_opt['Branch_init_iters'] else 1
if train_opt['gan_type'] == 'wgan-gp':
self.random_pt = torch.Tensor(1, 1, 1, 1).to(self.device)
# gradient penalty loss
self.cri_gp = GradientPenaltyLoss(device=self.device).to(self.device)
                self.l_gp_w = train_opt['gp_weight']
# gradient_pixel_loss
if train_opt['gradient_pixel_weight'] > 0:
self.cri_pix_grad = nn.MSELoss().to(self.device)
self.l_pix_grad_w = train_opt['gradient_pixel_weight']
else:
self.cri_pix_grad = None
# gradient_gan_loss
if train_opt['gradient_gan_weight'] > 0:
self.cri_grad_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device)
self.l_gan_grad_w = train_opt['gradient_gan_weight']
else:
self.cri_grad_gan = None
# G_grad pixel loss
if train_opt['pixel_branch_weight'] > 0:
l_pix_type = train_opt['pixel_branch_criterion']
if l_pix_type == 'l1':
self.cri_pix_branch = nn.L1Loss().to(self.device)
elif l_pix_type == 'l2':
self.cri_pix_branch = nn.MSELoss().to(self.device)
else:
raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_pix_type))
self.l_pix_branch_w = train_opt['pixel_branch_weight']
else:
logger.info('Remove G_grad pixel loss.')
self.cri_pix_branch = None
# optimizers
# G
wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
optim_params = []
for k, v in self.netG.named_parameters(): # optimize part of the model
if v.requires_grad:
optim_params.append(v)
else:
logger.warning('Params [{:s}] will not optimize.'.format(k))
self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'], \
weight_decay=wd_G, betas=(train_opt['beta1_G'], 0.999))
self.optimizers.append(self.optimizer_G)
# D
wd_D = train_opt['weight_decay_D'] if train_opt['weight_decay_D'] else 0
self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=train_opt['lr_D'], \
weight_decay=wd_D, betas=(train_opt['beta1_D'], 0.999))
self.optimizers.append(self.optimizer_D)
# D_grad
wd_D_grad = train_opt['weight_decay_D'] if train_opt['weight_decay_D'] else 0
            self.optimizer_D_grad = torch.optim.Adam(self.netD_grad.parameters(), lr=train_opt['lr_D'], \
                                                     weight_decay=wd_D_grad, betas=(train_opt['beta1_D'], 0.999))
self.optimizers.append(self.optimizer_D_grad)
# schedulers
if train_opt['lr_scheme'] == 'MultiStepLR':
for optimizer in self.optimizers:
self.schedulers.append(lr_scheduler.MultiStepLR(optimizer, \
train_opt['lr_steps'], train_opt['lr_gamma']))
else:
                raise NotImplementedError('Only the MultiStepLR learning rate scheme is currently supported.')
self.log_dict = OrderedDict()
self.get_grad = Get_gradient()
self.get_grad_nopadding = Get_gradient_nopadding()
def feed_data(self, data, need_HR=True):
# LR
self.var_L = data['LQ'].to(self.device)
if need_HR: # train or val
self.var_H = data['GT'].to(self.device)
input_ref = data['ref'] if 'ref' in data else data['GT']
self.var_ref = input_ref.to(self.device)
def optimize_parameters(self, step):
# G
for p in self.netD.parameters():
p.requires_grad = False
for p in self.netD_grad.parameters():
p.requires_grad = False
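        # During branch pretraining, freeze everything except the fusion ('f_')
        # layers for the first Branch_init_iters steps, then unfreeze.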
if(self.Branch_pretrain):
if(step < self.Branch_init_iters):
for k,v in self.netG.named_parameters():
if 'f_' not in k :
v.requires_grad=False
else:
for k,v in self.netG.named_parameters():
if 'f_' not in k :
v.requires_grad=True
self.optimizer_G.zero_grad()
self.fake_H_branch, self.fake_H, self.grad_LR = self.netG(self.var_L)
self.fake_H_grad = self.get_grad(self.fake_H)
self.var_H_grad = self.get_grad(self.var_H)
self.var_ref_grad = self.get_grad(self.var_ref)
self.var_H_grad_nopadding = self.get_grad_nopadding(self.var_H)
l_g_total = 0
if step % self.D_update_ratio == 0 and step > self.D_init_iters:
if self.cri_pix: # pixel loss
l_g_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.var_H)
l_g_total += l_g_pix
if self.cri_fea: # feature loss
real_fea = self.netF(self.var_H).detach()
fake_fea = self.netF(self.fake_H)
l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea)
l_g_total += l_g_fea
if self.cri_pix_grad: #gradient pixel loss
l_g_pix_grad = self.l_pix_grad_w * self.cri_pix_grad(self.fake_H_grad, self.var_H_grad)
l_g_total += l_g_pix_grad
if self.cri_pix_branch: #branch pixel loss
l_g_pix_grad_branch = self.l_pix_branch_w * self.cri_pix_branch(self.fake_H_branch, self.var_H_grad_nopadding)
l_g_total += l_g_pix_grad_branch
            # generator adversarial loss (relativistic average GAN)
pred_g_fake = self.netD(self.fake_H)
pred_d_real = self.netD(self.var_ref).detach()
l_g_gan = self.l_gan_w * (self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False) +
self.cri_gan(pred_g_fake - torch.mean(pred_d_real), True)) / 2
l_g_total += l_g_gan
            # generator adversarial loss on the gradient maps
pred_g_fake_grad = self.netD_grad(self.fake_H_grad)
pred_d_real_grad = self.netD_grad(self.var_ref_grad).detach()
l_g_gan_grad = self.l_gan_grad_w * (self.cri_grad_gan(pred_d_real_grad - torch.mean(pred_g_fake_grad), False) +
self.cri_grad_gan(pred_g_fake_grad - torch.mean(pred_d_real_grad), True)) /2
l_g_total += l_g_gan_grad
l_g_total.backward()
self.optimizer_G.step()
# D
for p in self.netD.parameters():
p.requires_grad = True
self.optimizer_D.zero_grad()
l_d_total = 0
pred_d_real = self.netD(self.var_ref)
pred_d_fake = self.netD(self.fake_H.detach()) # detach to avoid BP to G
l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True)
l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real), False)
l_d_total = (l_d_real + l_d_fake) / 2
if self.opt['train']['gan_type'] == 'wgan-gp':
batch_size = self.var_ref.size(0)
if self.random_pt.size(0) != batch_size:
self.random_pt.resize_(batch_size, 1, 1, 1)
self.random_pt.uniform_() # Draw random interpolation points
interp = self.random_pt * self.fake_H.detach() + (1 - self.random_pt) * self.var_ref
interp.requires_grad = True
            interp_crit = self.netD(interp)
l_d_gp = self.l_gp_w * self.cri_gp(interp, interp_crit)
l_d_total += l_d_gp
l_d_total.backward()
self.optimizer_D.step()
for p in self.netD_grad.parameters():
p.requires_grad = True
self.optimizer_D_grad.zero_grad()
l_d_total_grad = 0
pred_d_real_grad = self.netD_grad(self.var_ref_grad)
pred_d_fake_grad = self.netD_grad(self.fake_H_grad.detach()) # detach to avoid BP to G
l_d_real_grad = self.cri_grad_gan(pred_d_real_grad - torch.mean(pred_d_fake_grad), True)
l_d_fake_grad = self.cri_grad_gan(pred_d_fake_grad - torch.mean(pred_d_real_grad), False)
l_d_total_grad = (l_d_real_grad + l_d_fake_grad) / 2
l_d_total_grad.backward()
self.optimizer_D_grad.step()
# Log sample images from first microbatch.
if step % 50 == 0:
import torchvision.utils as utils
sample_save_path = os.path.join(self.opt['path']['models'], "..", "temp")
os.makedirs(os.path.join(sample_save_path, "hr"), exist_ok=True)
os.makedirs(os.path.join(sample_save_path, "lr"), exist_ok=True)
os.makedirs(os.path.join(sample_save_path, "gen"), exist_ok=True)
# fed_LQ is not chunked.
utils.save_image(self.var_H.cpu(), os.path.join(sample_save_path, "hr", "%05i.png" % (step,)))
utils.save_image(self.var_L.cpu(), os.path.join(sample_save_path, "lr", "%05i.png" % (step,)))
utils.save_image(self.fake_H.cpu(), os.path.join(sample_save_path, "gen", "%05i.png" % (step,)))
# set log
if step % self.D_update_ratio == 0 and step > self.D_init_iters:
# G
if self.cri_pix:
self.log_dict['l_g_pix'] = l_g_pix.item()
if self.cri_fea:
self.log_dict['l_g_fea'] = l_g_fea.item()
self.log_dict['l_g_gan'] = l_g_gan.item()
if self.cri_pix_branch: #branch pixel loss
self.log_dict['l_g_pix_grad_branch'] = l_g_pix_grad_branch.item()
# D
self.log_dict['l_d_real'] = l_d_real.item()
self.log_dict['l_d_fake'] = l_d_fake.item()
# D_grad
self.log_dict['l_d_real_grad'] = l_d_real_grad.item()
self.log_dict['l_d_fake_grad'] = l_d_fake_grad.item()
if self.opt['train']['gan_type'] == 'wgan-gp':
self.log_dict['l_d_gp'] = l_d_gp.item()
        # D outputs
        self.log_dict['D_real'] = torch.mean(pred_d_real.detach()).item()
        self.log_dict['D_fake'] = torch.mean(pred_d_fake.detach()).item()
        # D_grad outputs
        self.log_dict['D_real_grad'] = torch.mean(pred_d_real_grad.detach()).item()
        self.log_dict['D_fake_grad'] = torch.mean(pred_d_fake_grad.detach()).item()
def test(self):
self.netG.eval()
with torch.no_grad():
self.fake_H_branch, self.fake_H, self.grad_LR = self.netG(self.var_L)
self.netG.train()
def get_current_log(self, step):
return self.log_dict
def get_current_visuals(self, need_HR=True):
out_dict = OrderedDict()
out_dict['LR'] = self.var_L.detach()[0].float().cpu()
out_dict['SR'] = self.fake_H.detach()[0].float().cpu()
out_dict['SR_branch'] = self.fake_H_branch.detach()[0].float().cpu()
out_dict['LR_grad'] = self.grad_LR.detach()[0].float().cpu()
if need_HR:
out_dict['HR'] = self.var_H.detach()[0].float().cpu()
return out_dict
def print_network(self):
# Generator
s, n = self.get_network_description(self.netG)
if isinstance(self.netG, nn.DataParallel):
net_struc_str = '{} - {}'.format(self.netG.__class__.__name__,
self.netG.module.__class__.__name__)
else:
net_struc_str = '{}'.format(self.netG.__class__.__name__)
logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
logger.info(s)
if self.is_train:
# Discriminator
s, n = self.get_network_description(self.netD)
if isinstance(self.netD, nn.DataParallel):
net_struc_str = '{} - {}'.format(self.netD.__class__.__name__,
self.netD.module.__class__.__name__)
else:
net_struc_str = '{}'.format(self.netD.__class__.__name__)
logger.info('Network D structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
logger.info(s)
if self.cri_fea: # F, Perceptual Network
s, n = self.get_network_description(self.netF)
if isinstance(self.netF, nn.DataParallel):
net_struc_str = '{} - {}'.format(self.netF.__class__.__name__,
self.netF.module.__class__.__name__)
else:
net_struc_str = '{}'.format(self.netF.__class__.__name__)
logger.info('Network F structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
logger.info(s)
def load(self):
load_path_G = self.opt['path']['pretrain_model_G']
if load_path_G is not None:
logger.info('Loading pretrained model for G [{:s}] ...'.format(load_path_G))
self.load_network(load_path_G, self.netG)
load_path_D = self.opt['path']['pretrain_model_D']
if self.opt['is_train'] and load_path_D is not None:
logger.info('Loading pretrained model for D [{:s}] ...'.format(load_path_D))
self.load_network(load_path_D, self.netD)
def save(self, iter_step):
self.save_network(self.netG, 'G', iter_step)
self.save_network(self.netD, 'D', iter_step)
self.save_network(self.netD_grad, 'D_grad', iter_step)


@@ -0,0 +1,654 @@
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from . import block as B
from . import spectral_norm as SN
class Get_gradient_nopadding(nn.Module):
def __init__(self):
super(Get_gradient_nopadding, self).__init__()
kernel_v = [[0, -1, 0],
[0, 0, 0],
[0, 1, 0]]
kernel_h = [[0, 0, 0],
[-1, 0, 1],
[0, 0, 0]]
kernel_h = torch.FloatTensor(kernel_h).unsqueeze(0).unsqueeze(0)
kernel_v = torch.FloatTensor(kernel_v).unsqueeze(0).unsqueeze(0)
self.weight_h = nn.Parameter(data = kernel_h, requires_grad = False)
self.weight_v = nn.Parameter(data = kernel_v, requires_grad = False)
def forward(self, x):
x_list = []
for i in range(x.shape[1]):
x_i = x[:, i]
x_i_v = F.conv2d(x_i.unsqueeze(1), self.weight_v, padding=1)
x_i_h = F.conv2d(x_i.unsqueeze(1), self.weight_h, padding=1)
x_i = torch.sqrt(torch.pow(x_i_v, 2) + torch.pow(x_i_h, 2) + 1e-6)
x_list.append(x_i)
x = torch.cat(x_list, dim = 1)
return x
####################
# Generator
####################
class SPSRNet(nn.Module):
def __init__(self, in_nc, out_nc, nf, nb, gc=32, upscale=4, norm_type=None, \
act_type='leakyrelu', mode='CNA', upsample_mode='upconv'):
super(SPSRNet, self).__init__()
n_upscale = int(math.log(upscale, 2))
if upscale == 3:
n_upscale = 1
fea_conv = B.conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None)
rb_blocks = [B.RRDB(nf, kernel_size=3, gc=32, stride=1, bias=True, pad_type='zero', \
norm_type=norm_type, act_type=act_type, mode='CNA') for _ in range(nb)]
LR_conv = B.conv_block(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None, mode=mode)
if upsample_mode == 'upconv':
upsample_block = B.upconv_blcok
elif upsample_mode == 'pixelshuffle':
upsample_block = B.pixelshuffle_block
else:
raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode))
if upscale == 3:
upsampler = upsample_block(nf, nf, 3, act_type=act_type)
else:
upsampler = [upsample_block(nf, nf, act_type=act_type) for _ in range(n_upscale)]
self.HR_conv0_new = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type)
self.HR_conv1_new = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=None)
self.model = B.sequential(fea_conv, B.ShortcutBlock(B.sequential(*rb_blocks, LR_conv)),\
*upsampler, self.HR_conv0_new)
self.get_g_nopadding = Get_gradient_nopadding()
self.b_fea_conv = B.conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None)
self.b_concat_1 = B.conv_block(2*nf, nf, kernel_size=3, norm_type=None, act_type = None)
self.b_block_1 = B.RRDB(nf*2, kernel_size=3, gc=32, stride=1, bias=True, pad_type='zero', \
norm_type=norm_type, act_type=act_type, mode='CNA')
self.b_concat_2 = B.conv_block(2*nf, nf, kernel_size=3, norm_type=None, act_type = None)
self.b_block_2 = B.RRDB(nf*2, kernel_size=3, gc=32, stride=1, bias=True, pad_type='zero', \
norm_type=norm_type, act_type=act_type, mode='CNA')
self.b_concat_3 = B.conv_block(2*nf, nf, kernel_size=3, norm_type=None, act_type = None)
self.b_block_3 = B.RRDB(nf*2, kernel_size=3, gc=32, stride=1, bias=True, pad_type='zero', \
norm_type=norm_type, act_type=act_type, mode='CNA')
self.b_concat_4 = B.conv_block(2*nf, nf, kernel_size=3, norm_type=None, act_type = None)
self.b_block_4 = B.RRDB(nf*2, kernel_size=3, gc=32, stride=1, bias=True, pad_type='zero', \
norm_type=norm_type, act_type=act_type, mode='CNA')
self.b_LR_conv = B.conv_block(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None, mode=mode)
if upsample_mode == 'upconv':
upsample_block = B.upconv_blcok
elif upsample_mode == 'pixelshuffle':
upsample_block = B.pixelshuffle_block
else:
raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode))
if upscale == 3:
b_upsampler = upsample_block(nf, nf, 3, act_type=act_type)
else:
b_upsampler = [upsample_block(nf, nf, act_type=act_type) for _ in range(n_upscale)]
b_HR_conv0 = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type)
b_HR_conv1 = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=None)
self.b_module = B.sequential(*b_upsampler, b_HR_conv0, b_HR_conv1)
self.conv_w = B.conv_block(nf, out_nc, kernel_size=1, norm_type=None, act_type=None)
self.f_concat = B.conv_block(nf*2, nf, kernel_size=3, norm_type=None, act_type=None)
self.f_block = B.RRDB(nf*2, kernel_size=3, gc=32, stride=1, bias=True, pad_type='zero', \
norm_type=norm_type, act_type=act_type, mode='CNA')
self.f_HR_conv0 = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type)
self.f_HR_conv1 = B.conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None)
def forward(self, x):
x_grad = self.get_g_nopadding(x)
x = self.model[0](x)
x, block_list = self.model[1](x)
x_ori = x
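        # Walk the RRDB trunk manually in groups of five so intermediate features
        # can be tapped for the gradient branch; this hard-codes nb >= 21 (23 is
        # the usual ESRGAN-style depth).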
for i in range(5):
x = block_list[i](x)
x_fea1 = x
for i in range(5):
x = block_list[i+5](x)
x_fea2 = x
for i in range(5):
x = block_list[i+10](x)
x_fea3 = x
for i in range(5):
x = block_list[i+15](x)
x_fea4 = x
x = block_list[20:](x)
        # shortcut
        x = x_ori + x
        x = self.model[2:](x)
x = self.HR_conv1_new(x)
x_b_fea = self.b_fea_conv(x_grad)
x_cat_1 = torch.cat([x_b_fea, x_fea1], dim=1)
x_cat_1 = self.b_block_1(x_cat_1)
x_cat_1 = self.b_concat_1(x_cat_1)
x_cat_2 = torch.cat([x_cat_1, x_fea2], dim=1)
x_cat_2 = self.b_block_2(x_cat_2)
x_cat_2 = self.b_concat_2(x_cat_2)
x_cat_3 = torch.cat([x_cat_2, x_fea3], dim=1)
x_cat_3 = self.b_block_3(x_cat_3)
x_cat_3 = self.b_concat_3(x_cat_3)
x_cat_4 = torch.cat([x_cat_3, x_fea4], dim=1)
x_cat_4 = self.b_block_4(x_cat_4)
x_cat_4 = self.b_concat_4(x_cat_4)
x_cat_4 = self.b_LR_conv(x_cat_4)
        # shortcut
        x_cat_4 = x_cat_4 + x_b_fea
x_branch = self.b_module(x_cat_4)
x_out_branch = self.conv_w(x_branch)
########
x_branch_d = x_branch
x_f_cat = torch.cat([x_branch_d, x], dim=1)
x_f_cat = self.f_block(x_f_cat)
x_out = self.f_concat(x_f_cat)
x_out = self.f_HR_conv0(x_out)
x_out = self.f_HR_conv1(x_out)
#########
return x_out_branch, x_out, x_grad
####################
# Discriminator
####################
# VGG style Discriminator with input size 128*128
class Discriminator_VGG_128(nn.Module):
def __init__(self, in_nc, base_nf, norm_type='batch', act_type='leakyrelu', mode='CNA'):
super(Discriminator_VGG_128, self).__init__()
# features
# hxw, c
# 128, 64
conv0 = B.conv_block(in_nc, base_nf, kernel_size=3, norm_type=None, act_type=act_type, \
mode=mode)
conv1 = B.conv_block(base_nf, base_nf, kernel_size=4, stride=2, norm_type=norm_type, \
act_type=act_type, mode=mode)
# 64, 64
conv2 = B.conv_block(base_nf, base_nf*2, kernel_size=3, stride=1, norm_type=norm_type, \
act_type=act_type, mode=mode)
conv3 = B.conv_block(base_nf*2, base_nf*2, kernel_size=4, stride=2, norm_type=norm_type, \
act_type=act_type, mode=mode)
# 32, 128
conv4 = B.conv_block(base_nf*2, base_nf*4, kernel_size=3, stride=1, norm_type=norm_type, \
act_type=act_type, mode=mode)
conv5 = B.conv_block(base_nf*4, base_nf*4, kernel_size=4, stride=2, norm_type=norm_type, \
act_type=act_type, mode=mode)
# 16, 256
conv6 = B.conv_block(base_nf*4, base_nf*8, kernel_size=3, stride=1, norm_type=norm_type, \
act_type=act_type, mode=mode)
conv7 = B.conv_block(base_nf*8, base_nf*8, kernel_size=4, stride=2, norm_type=norm_type, \
act_type=act_type, mode=mode)
# 8, 512
conv8 = B.conv_block(base_nf*8, base_nf*8, kernel_size=3, stride=1, norm_type=norm_type, \
act_type=act_type, mode=mode)
conv9 = B.conv_block(base_nf*8, base_nf*8, kernel_size=4, stride=2, norm_type=norm_type, \
act_type=act_type, mode=mode)
# 4, 512
self.features = B.sequential(conv0, conv1, conv2, conv3, conv4, conv5, conv6, conv7, conv8,\
conv9)
# classifier
self.classifier = nn.Sequential(
nn.Linear(512 * 4 * 4, 100), nn.LeakyReLU(0.2, True), nn.Linear(100, 1))
def forward(self, x):
x = self.features(x)
x = x.view(x.size(0), -1)
x = self.classifier(x)
return x
# VGG style Discriminator with input size 96*96
class Discriminator_VGG_96(nn.Module):
def __init__(self, in_nc, base_nf, norm_type='batch', act_type='leakyrelu', mode='CNA'):
super(Discriminator_VGG_96, self).__init__()
# features
# hxw, c
# 96, 3
conv0 = B.conv_block(in_nc, base_nf, kernel_size=3, norm_type=None, act_type=act_type, \
mode=mode)
conv1 = B.conv_block(base_nf, base_nf, kernel_size=4, stride=2, norm_type=norm_type, \
act_type=act_type, mode=mode)
# 48, 64
conv2 = B.conv_block(base_nf, base_nf*2, kernel_size=3, stride=1, norm_type=norm_type, \
act_type=act_type, mode=mode)
conv3 = B.conv_block(base_nf*2, base_nf*2, kernel_size=4, stride=2, norm_type=norm_type, \
act_type=act_type, mode=mode)
# 24, 128
conv4 = B.conv_block(base_nf*2, base_nf*4, kernel_size=3, stride=1, norm_type=norm_type, \
act_type=act_type, mode=mode)
conv5 = B.conv_block(base_nf*4, base_nf*4, kernel_size=4, stride=2, norm_type=norm_type, \
act_type=act_type, mode=mode)
# 12, 256
conv6 = B.conv_block(base_nf*4, base_nf*8, kernel_size=3, stride=1, norm_type=norm_type, \
act_type=act_type, mode=mode)
conv7 = B.conv_block(base_nf*8, base_nf*8, kernel_size=4, stride=2, norm_type=norm_type, \
act_type=act_type, mode=mode)
# 6, 512
self.features = B.sequential(conv0, conv1, conv2, conv3, conv4, conv5, conv6, conv7)
# classifier
self.classifier = nn.Sequential(
nn.Linear(512 * 6 * 6, 100), nn.LeakyReLU(0.2, True), nn.Linear(100, 1))
def forward(self, x):
x = self.features(x)
x = x.view(x.size(0), -1)
x = self.classifier(x)
return x
# VGG style Discriminator with input size 64*64
class Discriminator_VGG_64(nn.Module):
def __init__(self, in_nc, base_nf, norm_type='batch', act_type='leakyrelu', mode='CNA'):
super(Discriminator_VGG_64, self).__init__()
# features
# hxw, c
# 64, 3
conv0 = B.conv_block(in_nc, base_nf, kernel_size=3, norm_type=None, act_type=act_type, \
mode=mode)
conv1 = B.conv_block(base_nf, base_nf, kernel_size=4, stride=2, norm_type=norm_type, \
act_type=act_type, mode=mode)
# 32, 64
conv2 = B.conv_block(base_nf, base_nf*2, kernel_size=3, stride=1, norm_type=norm_type, \
act_type=act_type, mode=mode)
conv3 = B.conv_block(base_nf*2, base_nf*2, kernel_size=4, stride=2, norm_type=norm_type, \
act_type=act_type, mode=mode)
# 16, 128
conv4 = B.conv_block(base_nf*2, base_nf*4, kernel_size=3, stride=1, norm_type=norm_type, \
act_type=act_type, mode=mode)
conv5 = B.conv_block(base_nf*4, base_nf*4, kernel_size=4, stride=2, norm_type=norm_type, \
act_type=act_type, mode=mode)
# 8, 256
conv6 = B.conv_block(base_nf*4, base_nf*8, kernel_size=3, stride=1, norm_type=norm_type, \
act_type=act_type, mode=mode)
conv7 = B.conv_block(base_nf*8, base_nf*8, kernel_size=4, stride=2, norm_type=norm_type, \
act_type=act_type, mode=mode)
# 4, 512
self.features = B.sequential(conv0, conv1, conv2, conv3, conv4, conv5, conv6, conv7)
# classifier
self.classifier = nn.Sequential(
nn.Linear(512 * 4 * 4, 100), nn.LeakyReLU(0.2, True), nn.Linear(100, 1))
def forward(self, x):
x = self.features(x)
x = x.view(x.size(0), -1)
x = self.classifier(x)
return x
# VGG style Discriminator with input size 32*32
class Discriminator_VGG_32(nn.Module):
def __init__(self, in_nc, base_nf, norm_type='batch', act_type='leakyrelu', mode='CNA'):
super(Discriminator_VGG_32, self).__init__()
# features
# hxw, c
# 32, 3
conv0 = B.conv_block(in_nc, base_nf, kernel_size=3, norm_type=None, act_type=act_type, \
mode=mode)
conv1 = B.conv_block(base_nf, base_nf, kernel_size=4, stride=2, norm_type=norm_type, \
act_type=act_type, mode=mode)
# 16, 64
conv2 = B.conv_block(base_nf, base_nf*2, kernel_size=3, stride=1, norm_type=norm_type, \
act_type=act_type, mode=mode)
conv3 = B.conv_block(base_nf*2, base_nf*2, kernel_size=4, stride=2, norm_type=norm_type, \
act_type=act_type, mode=mode)
# 8, 128
conv4 = B.conv_block(base_nf*2, base_nf*4, kernel_size=3, stride=1, norm_type=norm_type, \
act_type=act_type, mode=mode)
conv5 = B.conv_block(base_nf*4, base_nf*4, kernel_size=4, stride=2, norm_type=norm_type, \
act_type=act_type, mode=mode)
# 4, 256
self.features = B.sequential(conv0, conv1, conv2, conv3, conv4, conv5)
# classifier
self.classifier = nn.Sequential(
nn.Linear(256 * 4 * 4, 100), nn.LeakyReLU(0.2, True), nn.Linear(100, 1))
def forward(self, x):
x = self.features(x)
x = x.view(x.size(0), -1)
x = self.classifier(x)
return x
# VGG style Discriminator with input size 16*16
class Discriminator_VGG_16(nn.Module):
def __init__(self, in_nc, base_nf, norm_type='batch', act_type='leakyrelu', mode='CNA'):
super(Discriminator_VGG_16, self).__init__()
# features
# hxw, c
# 16, 3
conv0 = B.conv_block(in_nc, base_nf, kernel_size=3, norm_type=None, act_type=act_type, \
mode=mode)
conv1 = B.conv_block(base_nf, base_nf, kernel_size=4, stride=2, norm_type=norm_type, \
act_type=act_type, mode=mode)
# 8, 64
conv2 = B.conv_block(base_nf, base_nf*2, kernel_size=3, stride=1, norm_type=norm_type, \
act_type=act_type, mode=mode)
conv3 = B.conv_block(base_nf*2, base_nf*2, kernel_size=4, stride=2, norm_type=norm_type, \
act_type=act_type, mode=mode)
# 4, 128
self.features = B.sequential(conv0, conv1, conv2, conv3)
# classifier
self.classifier = nn.Sequential(
nn.Linear(128 * 4 * 4, 100), nn.LeakyReLU(0.2, True), nn.Linear(100, 1))
def forward(self, x):
x = self.features(x)
x = x.view(x.size(0), -1)
x = self.classifier(x)
return x
# VGG style Discriminator with input size 128*128, Spectral Normalization
class Discriminator_VGG_128_SN(nn.Module):
def __init__(self):
super(Discriminator_VGG_128_SN, self).__init__()
# features
# hxw, c
# 128, 64
self.lrelu = nn.LeakyReLU(0.2, True)
self.conv0 = SN.spectral_norm(nn.Conv2d(3, 64, 3, 1, 1))
self.conv1 = SN.spectral_norm(nn.Conv2d(64, 64, 4, 2, 1))
# 64, 64
self.conv2 = SN.spectral_norm(nn.Conv2d(64, 128, 3, 1, 1))
self.conv3 = SN.spectral_norm(nn.Conv2d(128, 128, 4, 2, 1))
# 32, 128
self.conv4 = SN.spectral_norm(nn.Conv2d(128, 256, 3, 1, 1))
self.conv5 = SN.spectral_norm(nn.Conv2d(256, 256, 4, 2, 1))
# 16, 256
self.conv6 = SN.spectral_norm(nn.Conv2d(256, 512, 3, 1, 1))
self.conv7 = SN.spectral_norm(nn.Conv2d(512, 512, 4, 2, 1))
# 8, 512
self.conv8 = SN.spectral_norm(nn.Conv2d(512, 512, 3, 1, 1))
self.conv9 = SN.spectral_norm(nn.Conv2d(512, 512, 4, 2, 1))
# 4, 512
# classifier
self.linear0 = SN.spectral_norm(nn.Linear(512 * 4 * 4, 100))
self.linear1 = SN.spectral_norm(nn.Linear(100, 1))
def forward(self, x):
x = self.lrelu(self.conv0(x))
x = self.lrelu(self.conv1(x))
x = self.lrelu(self.conv2(x))
x = self.lrelu(self.conv3(x))
x = self.lrelu(self.conv4(x))
x = self.lrelu(self.conv5(x))
x = self.lrelu(self.conv6(x))
x = self.lrelu(self.conv7(x))
x = self.lrelu(self.conv8(x))
x = self.lrelu(self.conv9(x))
x = x.view(x.size(0), -1)
x = self.lrelu(self.linear0(x))
x = self.linear1(x)
return x
class Discriminator_VGG_96(nn.Module):
def __init__(self, in_nc, base_nf, norm_type='batch', act_type='leakyrelu', mode='CNA'):
super(Discriminator_VGG_96, self).__init__()
# features
# hxw, c
# 96, 64
conv0 = B.conv_block(in_nc, base_nf, kernel_size=3, norm_type=None, act_type=act_type, \
mode=mode)
conv1 = B.conv_block(base_nf, base_nf, kernel_size=4, stride=2, norm_type=norm_type, \
act_type=act_type, mode=mode)
# 48, 64
conv2 = B.conv_block(base_nf, base_nf*2, kernel_size=3, stride=1, norm_type=norm_type, \
act_type=act_type, mode=mode)
conv3 = B.conv_block(base_nf*2, base_nf*2, kernel_size=4, stride=2, norm_type=norm_type, \
act_type=act_type, mode=mode)
# 24, 128
conv4 = B.conv_block(base_nf*2, base_nf*4, kernel_size=3, stride=1, norm_type=norm_type, \
act_type=act_type, mode=mode)
conv5 = B.conv_block(base_nf*4, base_nf*4, kernel_size=4, stride=2, norm_type=norm_type, \
act_type=act_type, mode=mode)
# 12, 256
conv6 = B.conv_block(base_nf*4, base_nf*8, kernel_size=3, stride=1, norm_type=norm_type, \
act_type=act_type, mode=mode)
conv7 = B.conv_block(base_nf*8, base_nf*8, kernel_size=4, stride=2, norm_type=norm_type, \
act_type=act_type, mode=mode)
# 6, 512
conv8 = B.conv_block(base_nf*8, base_nf*8, kernel_size=3, stride=1, norm_type=norm_type, \
act_type=act_type, mode=mode)
conv9 = B.conv_block(base_nf*8, base_nf*8, kernel_size=4, stride=2, norm_type=norm_type, \
act_type=act_type, mode=mode)
# 3, 512
self.features = B.sequential(conv0, conv1, conv2, conv3, conv4, conv5, conv6, conv7, conv8,\
conv9)
# classifier
self.classifier = nn.Sequential(
nn.Linear(512 * 3 * 3, 100), nn.LeakyReLU(0.2, True), nn.Linear(100, 1))
def forward(self, x):
x = self.features(x)
x = x.view(x.size(0), -1)
x = self.classifier(x)
return x
class Discriminator_VGG_192(nn.Module):
def __init__(self, in_nc, base_nf, norm_type='batch', act_type='leakyrelu', mode='CNA'):
super(Discriminator_VGG_192, self).__init__()
# features
# hxw, c
# 192, 64
conv0 = B.conv_block(in_nc, base_nf, kernel_size=3, norm_type=None, act_type=act_type, \
mode=mode)
conv1 = B.conv_block(base_nf, base_nf, kernel_size=4, stride=2, norm_type=norm_type, \
act_type=act_type, mode=mode)
# 96, 64
conv2 = B.conv_block(base_nf, base_nf*2, kernel_size=3, stride=1, norm_type=norm_type, \
act_type=act_type, mode=mode)
conv3 = B.conv_block(base_nf*2, base_nf*2, kernel_size=4, stride=2, norm_type=norm_type, \
act_type=act_type, mode=mode)
# 48, 128
conv4 = B.conv_block(base_nf*2, base_nf*4, kernel_size=3, stride=1, norm_type=norm_type, \
act_type=act_type, mode=mode)
conv5 = B.conv_block(base_nf*4, base_nf*4, kernel_size=4, stride=2, norm_type=norm_type, \
act_type=act_type, mode=mode)
# 24, 256
conv6 = B.conv_block(base_nf*4, base_nf*8, kernel_size=3, stride=1, norm_type=norm_type, \
act_type=act_type, mode=mode)
conv7 = B.conv_block(base_nf*8, base_nf*8, kernel_size=4, stride=2, norm_type=norm_type, \
act_type=act_type, mode=mode)
# 12, 512
conv8 = B.conv_block(base_nf*8, base_nf*8, kernel_size=3, stride=1, norm_type=norm_type, \
act_type=act_type, mode=mode)
conv9 = B.conv_block(base_nf*8, base_nf*8, kernel_size=4, stride=2, norm_type=norm_type, \
act_type=act_type, mode=mode)
# 6, 512
conv10 = B.conv_block(base_nf*8, base_nf*8, kernel_size=3, stride=1, norm_type=norm_type, \
act_type=act_type, mode=mode)
conv11 = B.conv_block(base_nf*8, base_nf*8, kernel_size=4, stride=2, norm_type=norm_type, \
act_type=act_type, mode=mode)
# 3, 512
self.features = B.sequential(conv0, conv1, conv2, conv3, conv4, conv5, conv6, conv7, conv8,\
conv9, conv10, conv11)
# classifier
self.classifier = nn.Sequential(
nn.Linear(512 * 3 * 3, 100), nn.LeakyReLU(0.2, True), nn.Linear(100, 1))
def forward(self, x):
x = self.features(x)
x = x.view(x.size(0), -1)
x = self.classifier(x)
return x
####################
# Perceptual Network
####################
class VGGFeatureExtractor(nn.Module):
def __init__(self,
feature_layer=34,
use_bn=False,
use_input_norm=True,
device=torch.device('cpu')):
super(VGGFeatureExtractor, self).__init__()
if use_bn:
model = torchvision.models.vgg19_bn(pretrained=True)
else:
model = torchvision.models.vgg19(pretrained=True)
self.use_input_norm = use_input_norm
if self.use_input_norm:
mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device)
# [0.485-1, 0.456-1, 0.406-1] if input in range [-1,1]
std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device)
# [0.229*2, 0.224*2, 0.225*2] if input in range [-1,1]
self.register_buffer('mean', mean)
self.register_buffer('std', std)
self.features = nn.Sequential(*list(model.features.children())[:(feature_layer + 1)])
# No need to BP to variable
for k, v in self.features.named_parameters():
v.requires_grad = False
def forward(self, x):
if self.use_input_norm:
x = (x - self.mean) / self.std
output = self.features(x)
return output
class ResNet101FeatureExtractor(nn.Module):
def __init__(self, use_input_norm=True, device=torch.device('cpu')):
super(ResNet101FeatureExtractor, self).__init__()
model = torchvision.models.resnet101(pretrained=True)
self.use_input_norm = use_input_norm
if self.use_input_norm:
mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device)
# [0.485-1, 0.456-1, 0.406-1] if input in range [-1,1]
std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device)
# [0.229*2, 0.224*2, 0.225*2] if input in range [-1,1]
self.register_buffer('mean', mean)
self.register_buffer('std', std)
self.features = nn.Sequential(*list(model.children())[:8])
# No need to BP to variable
for k, v in self.features.named_parameters():
v.requires_grad = False
def forward(self, x):
if self.use_input_norm:
x = (x - self.mean) / self.std
output = self.features(x)
return output
class MINCNet(nn.Module):
def __init__(self):
super(MINCNet, self).__init__()
self.ReLU = nn.ReLU(True)
self.conv11 = nn.Conv2d(3, 64, 3, 1, 1)
self.conv12 = nn.Conv2d(64, 64, 3, 1, 1)
self.maxpool1 = nn.MaxPool2d(2, stride=2, padding=0, ceil_mode=True)
self.conv21 = nn.Conv2d(64, 128, 3, 1, 1)
self.conv22 = nn.Conv2d(128, 128, 3, 1, 1)
self.maxpool2 = nn.MaxPool2d(2, stride=2, padding=0, ceil_mode=True)
self.conv31 = nn.Conv2d(128, 256, 3, 1, 1)
self.conv32 = nn.Conv2d(256, 256, 3, 1, 1)
self.conv33 = nn.Conv2d(256, 256, 3, 1, 1)
self.maxpool3 = nn.MaxPool2d(2, stride=2, padding=0, ceil_mode=True)
self.conv41 = nn.Conv2d(256, 512, 3, 1, 1)
self.conv42 = nn.Conv2d(512, 512, 3, 1, 1)
self.conv43 = nn.Conv2d(512, 512, 3, 1, 1)
self.maxpool4 = nn.MaxPool2d(2, stride=2, padding=0, ceil_mode=True)
self.conv51 = nn.Conv2d(512, 512, 3, 1, 1)
self.conv52 = nn.Conv2d(512, 512, 3, 1, 1)
self.conv53 = nn.Conv2d(512, 512, 3, 1, 1)
def forward(self, x):
out = self.ReLU(self.conv11(x))
out = self.ReLU(self.conv12(out))
out = self.maxpool1(out)
out = self.ReLU(self.conv21(out))
out = self.ReLU(self.conv22(out))
out = self.maxpool2(out)
out = self.ReLU(self.conv31(out))
out = self.ReLU(self.conv32(out))
out = self.ReLU(self.conv33(out))
out = self.maxpool3(out)
out = self.ReLU(self.conv41(out))
out = self.ReLU(self.conv42(out))
out = self.ReLU(self.conv43(out))
out = self.maxpool4(out)
out = self.ReLU(self.conv51(out))
out = self.ReLU(self.conv52(out))
out = self.conv53(out)
return out
class MINCFeatureExtractor(nn.Module):
def __init__(self, feature_layer=34, use_bn=False, use_input_norm=True, \
device=torch.device('cpu')):
super(MINCFeatureExtractor, self).__init__()
self.features = MINCNet()
self.features.load_state_dict(
torch.load('../experiments/pretrained_models/VGG16minc_53.pth'), strict=True)
self.features.eval()
# No need to BP to variable
for k, v in self.features.named_parameters():
v.requires_grad = False
def forward(self, x):
output = self.features(x)
return output
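A quick smoke test of the generator's three-output contract (a sketch: the module path comes from the import in SPSR_networks.py further down and assumes running from the codes/ directory; nf=64/nb=23 are the usual ESRGAN-style settings, not values taken from this commit, and nb must be at least 21 because forward walks the trunk in hard-coded groups of five):

import torch
import models.SPSR_modules.architecture as arch

net = arch.SPSRNet(in_nc=3, out_nc=3, nf=64, nb=23, upscale=4)
lr = torch.randn(1, 3, 32, 32)
sr_branch, sr, lr_grad = net(lr)
print(sr.shape)         # torch.Size([1, 3, 128, 128]) - fused SR output
print(sr_branch.shape)  # torch.Size([1, 3, 128, 128]) - gradient-branch SR
print(lr_grad.shape)    # torch.Size([1, 3, 32, 32])   - gradient map of the input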


@@ -0,0 +1,258 @@
from collections import OrderedDict
import torch
import torch.nn as nn
####################
# Basic blocks
####################
def act(act_type, inplace=True, neg_slope=0.2, n_prelu=1):
# helper selecting activation
# neg_slope: for leakyrelu and init of prelu
# n_prelu: for p_relu num_parameters
act_type = act_type.lower()
if act_type == 'relu':
layer = nn.ReLU(inplace)
elif act_type == 'leakyrelu':
layer = nn.LeakyReLU(neg_slope, inplace)
elif act_type == 'prelu':
layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope)
else:
raise NotImplementedError('activation layer [{:s}] is not found'.format(act_type))
return layer
def norm(norm_type, nc):
# helper selecting normalization layer
norm_type = norm_type.lower()
if norm_type == 'batch':
layer = nn.BatchNorm2d(nc, affine=True)
elif norm_type == 'instance':
layer = nn.InstanceNorm2d(nc, affine=False)
else:
raise NotImplementedError('normalization layer [{:s}] is not found'.format(norm_type))
return layer
def pad(pad_type, padding):
# helper selecting padding layer
# if padding is 'zero', do by conv layers
pad_type = pad_type.lower()
if padding == 0:
return None
if pad_type == 'reflect':
layer = nn.ReflectionPad2d(padding)
elif pad_type == 'replicate':
layer = nn.ReplicationPad2d(padding)
else:
raise NotImplementedError('padding layer [{:s}] is not implemented'.format(pad_type))
return layer
def get_valid_padding(kernel_size, dilation):
kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1)
padding = (kernel_size - 1) // 2
return padding
class ConcatBlock(nn.Module):
# Concat the output of a submodule to its input
def __init__(self, submodule):
super(ConcatBlock, self).__init__()
self.sub = submodule
def forward(self, x):
output = torch.cat((x, self.sub(x)), dim=1)
return output
def __repr__(self):
tmpstr = 'Identity .. \n|'
modstr = self.sub.__repr__().replace('\n', '\n|')
tmpstr = tmpstr + modstr
return tmpstr
class ShortcutBlock(nn.Module):
    # Returns the input together with the wrapped submodule so callers (e.g.
    # SPSRNet.forward) can run the inner blocks one at a time and apply the
    # residual sum themselves.
def __init__(self, submodule):
super(ShortcutBlock, self).__init__()
self.sub = submodule
def forward(self, x):
return x, self.sub
def __repr__(self):
tmpstr = 'Identity + \n|'
modstr = self.sub.__repr__().replace('\n', '\n|')
tmpstr = tmpstr + modstr
return tmpstr
def sequential(*args):
# Flatten Sequential. It unwraps nn.Sequential.
if len(args) == 1:
if isinstance(args[0], OrderedDict):
raise NotImplementedError('sequential does not support OrderedDict input.')
return args[0] # No sequential is needed.
modules = []
for module in args:
if isinstance(module, nn.Sequential):
for submodule in module.children():
modules.append(submodule)
elif isinstance(module, nn.Module):
modules.append(module)
return nn.Sequential(*modules)
def conv_block(in_nc, out_nc, kernel_size, stride=1, dilation=1, groups=1, bias=True, \
pad_type='zero', norm_type=None, act_type='relu', mode='CNA'):
'''
Conv layer with padding, normalization, activation
mode: CNA --> Conv -> Norm -> Act
NAC --> Norm -> Act --> Conv (Identity Mappings in Deep Residual Networks, ECCV16)
'''
assert mode in ['CNA', 'NAC', 'CNAC'], 'Wrong conv mode [{:s}]'.format(mode)
padding = get_valid_padding(kernel_size, dilation)
p = pad(pad_type, padding) if pad_type and pad_type != 'zero' else None
padding = padding if pad_type == 'zero' else 0
c = nn.Conv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding, \
dilation=dilation, bias=bias, groups=groups)
a = act(act_type) if act_type else None
if 'CNA' in mode:
n = norm(norm_type, out_nc) if norm_type else None
return sequential(p, c, n, a)
elif mode == 'NAC':
if norm_type is None and act_type is not None:
a = act(act_type, inplace=False)
# Important!
# input----ReLU(inplace)----Conv--+----output
# |________________________|
# inplace ReLU will modify the input, therefore wrong output
n = norm(norm_type, in_nc) if norm_type else None
return sequential(n, a, p, c)
####################
# Useful blocks
####################
class ResNetBlock(nn.Module):
'''
ResNet Block, 3-3 style
with extra residual scaling used in EDSR
(Enhanced Deep Residual Networks for Single Image Super-Resolution, CVPRW 17)
'''
def __init__(self, in_nc, mid_nc, out_nc, kernel_size=3, stride=1, dilation=1, groups=1, \
bias=True, pad_type='zero', norm_type=None, act_type='relu', mode='CNA', res_scale=1):
super(ResNetBlock, self).__init__()
conv0 = conv_block(in_nc, mid_nc, kernel_size, stride, dilation, groups, bias, pad_type, \
norm_type, act_type, mode)
if mode == 'CNA':
act_type = None
if mode == 'CNAC': # Residual path: |-CNAC-|
act_type = None
norm_type = None
conv1 = conv_block(mid_nc, out_nc, kernel_size, stride, dilation, groups, bias, pad_type, \
norm_type, act_type, mode)
# if in_nc != out_nc:
# self.project = conv_block(in_nc, out_nc, 1, stride, dilation, 1, bias, pad_type, \
# None, None)
# print('Need a projecter in ResNetBlock.')
# else:
# self.project = lambda x:x
self.res = sequential(conv0, conv1)
self.res_scale = res_scale
def forward(self, x):
res = self.res(x).mul(self.res_scale)
return x + res
class ResidualDenseBlock_5C(nn.Module):
'''
Residual Dense Block
style: 5 convs
The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)
'''
def __init__(self, nc, kernel_size=3, gc=32, stride=1, bias=True, pad_type='zero', \
norm_type=None, act_type='leakyrelu', mode='CNA'):
super(ResidualDenseBlock_5C, self).__init__()
# gc: growth channel, i.e. intermediate channels
self.conv1 = conv_block(nc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, \
norm_type=norm_type, act_type=act_type, mode=mode)
self.conv2 = conv_block(nc+gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, \
norm_type=norm_type, act_type=act_type, mode=mode)
self.conv3 = conv_block(nc+2*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, \
norm_type=norm_type, act_type=act_type, mode=mode)
self.conv4 = conv_block(nc+3*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, \
norm_type=norm_type, act_type=act_type, mode=mode)
if mode == 'CNA':
last_act = None
else:
last_act = act_type
self.conv5 = conv_block(nc+4*gc, nc, 3, stride, bias=bias, pad_type=pad_type, \
norm_type=norm_type, act_type=last_act, mode=mode)
def forward(self, x):
x1 = self.conv1(x)
x2 = self.conv2(torch.cat((x, x1), 1))
x3 = self.conv3(torch.cat((x, x1, x2), 1))
x4 = self.conv4(torch.cat((x, x1, x2, x3), 1))
x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
return x5.mul(0.2) + x
class RRDB(nn.Module):
'''
Residual in Residual Dense Block
(ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks)
'''
def __init__(self, nc, kernel_size=3, gc=32, stride=1, bias=True, pad_type='zero', \
norm_type=None, act_type='leakyrelu', mode='CNA'):
super(RRDB, self).__init__()
self.RDB1 = ResidualDenseBlock_5C(nc, kernel_size, gc, stride, bias, pad_type, \
norm_type, act_type, mode)
self.RDB2 = ResidualDenseBlock_5C(nc, kernel_size, gc, stride, bias, pad_type, \
norm_type, act_type, mode)
self.RDB3 = ResidualDenseBlock_5C(nc, kernel_size, gc, stride, bias, pad_type, \
norm_type, act_type, mode)
def forward(self, x):
out = self.RDB1(x)
out = self.RDB2(out)
out = self.RDB3(out)
return out.mul(0.2) + x
####################
# Upsampler
####################
def pixelshuffle_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True, \
pad_type='zero', norm_type=None, act_type='relu'):
'''
Pixel shuffle layer
(Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional
Neural Network, CVPR17)
'''
conv = conv_block(in_nc, out_nc * (upscale_factor ** 2), kernel_size, stride, bias=bias, \
pad_type=pad_type, norm_type=None, act_type=None)
pixel_shuffle = nn.PixelShuffle(upscale_factor)
n = norm(norm_type, out_nc) if norm_type else None
a = act(act_type) if act_type else None
return sequential(conv, pixel_shuffle, n, a)
def upconv_blcok(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True, \
pad_type='zero', norm_type=None, act_type='relu', mode='nearest'):
# Up conv
# described in https://distill.pub/2016/deconv-checkerboard/
upsample = nn.Upsample(scale_factor=upscale_factor, mode=mode)
conv = conv_block(in_nc, out_nc, kernel_size, stride, bias=bias, \
pad_type=pad_type, norm_type=norm_type, act_type=act_type)
return sequential(upsample, conv)
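For reference, pixelshuffle_block(64, 64, upscale_factor=2) with the defaults above expands to roughly the following standalone PyTorch (a sketch of the construction, not an import of this module):

import torch
import torch.nn as nn

up = nn.Sequential(
    nn.Conv2d(64, 64 * 2 ** 2, kernel_size=3, stride=1, padding=1),  # widen channels by r^2
    nn.PixelShuffle(2),                                              # fold channels into 2x spatial size
    nn.ReLU(inplace=True),                                           # act_type='relu' default
)
x = torch.randn(1, 64, 16, 16)
print(up(x).shape)  # torch.Size([1, 64, 32, 32])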


@@ -0,0 +1,60 @@
import torch
import torch.nn as nn
# Define GAN loss: [vanilla | lsgan | wgan-gp]
class GANLoss(nn.Module):
def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0):
super(GANLoss, self).__init__()
self.gan_type = gan_type.lower()
self.real_label_val = real_label_val
self.fake_label_val = fake_label_val
if self.gan_type == 'vanilla':
self.loss = nn.BCEWithLogitsLoss()
elif self.gan_type == 'lsgan':
self.loss = nn.MSELoss()
elif self.gan_type == 'wgan-gp':
def wgan_loss(input, target):
# target is boolean
return -1 * input.mean() if target else input.mean()
self.loss = wgan_loss
else:
raise NotImplementedError('GAN type [{:s}] is not found'.format(self.gan_type))
def get_target_label(self, input, target_is_real):
if self.gan_type == 'wgan-gp':
return target_is_real
if target_is_real:
return torch.empty_like(input).fill_(self.real_label_val)
else:
return torch.empty_like(input).fill_(self.fake_label_val)
def forward(self, input, target_is_real):
target_label = self.get_target_label(input, target_is_real)
loss = self.loss(input, target_label)
return loss
class GradientPenaltyLoss(nn.Module):
def __init__(self, device=torch.device('cpu')):
super(GradientPenaltyLoss, self).__init__()
self.register_buffer('grad_outputs', torch.Tensor())
self.grad_outputs = self.grad_outputs.to(device)
def get_grad_outputs(self, input):
if self.grad_outputs.size() != input.size():
self.grad_outputs.resize_(input.size()).fill_(1.0)
return self.grad_outputs
def forward(self, interp, interp_crit):
grad_outputs = self.get_grad_outputs(interp_crit)
grad_interp = torch.autograd.grad(outputs=interp_crit, inputs=interp, \
grad_outputs=grad_outputs, create_graph=True, retain_graph=True, only_inputs=True)[0]
grad_interp = grad_interp.view(grad_interp.size(0), -1)
grad_interp_norm = grad_interp.norm(2, dim=1)
loss = ((grad_interp_norm - 1)**2).mean()
return loss
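A short sketch of how GANLoss is driven (the import path is the one SPSR_model.py uses above; the logits are random stand-ins for discriminator outputs):

import torch
from models.SPSR_modules.loss import GANLoss

cri = GANLoss('vanilla', 1.0, 0.0)
logits = torch.randn(4, 1)
loss_real = cri(logits, True)   # BCE-with-logits against all-ones targets
loss_fake = cri(logits, False)  # ... against all-zeros targets
print(loss_real.item(), loss_fake.item())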


@@ -0,0 +1,94 @@
import random
import torch
import numpy as np
def _get_random_crop_indices(crop_region, crop_size):
'''
    crop_region: (start_y, end_y, start_x, end_x)
crop_size: (y, x)
'''
region_size = (crop_region[1] - crop_region[0], crop_region[3] - crop_region[2])
if region_size[0] < crop_size[0] or region_size[1] < crop_size[1]:
print(region_size, crop_size)
assert region_size[0] >= crop_size[0] and region_size[1] >= crop_size[1]
if region_size[0] == crop_size[0]:
start_y = crop_region[0]
else:
start_y = random.choice(range(crop_region[0], crop_region[1] - crop_size[0]))
if region_size[1] == crop_size[1]:
start_x = crop_region[2]
else:
start_x = random.choice(range(crop_region[2], crop_region[3] - crop_size[1]))
return start_y, start_y + crop_size[0], start_x, start_x + crop_size[1]
def _get_adaptive_crop_indices(crop_region, crop_size, num_candidate, dist_map, min_diff=False):
candidates = [_get_random_crop_indices(crop_region, crop_size) for _ in range(num_candidate)]
max_choice = candidates[0]
min_choice = candidates[0]
max_dist = 0
    min_dist = np.inf
with torch.no_grad():
for c in candidates:
start_y, end_y, start_x, end_x = c
dist = torch.sum(dist_map[start_y: end_y, start_x: end_x])
if dist > max_dist:
max_dist = dist
max_choice = c
if dist < min_dist:
min_dist = dist
min_choice = c
if min_diff:
return min_choice
else:
return max_choice
def get_split_list(divisor, dividend):
split_list = [dividend // divisor for _ in range(divisor - 1)]
split_list.append(dividend - (dividend // divisor) * (divisor - 1))
return split_list
def random_sampler(pic_size, crop_dict):
crop_region = (0, pic_size[0], 0, pic_size[1])
crop_res_dict = {}
for k, v in crop_dict.items():
crop_size = (int(k), int(k))
crop_res_dict[k] = [_get_random_crop_indices(crop_region, crop_size) for _ in range(v)]
return crop_res_dict
def region_sampler(crop_region, crop_dict):
crop_res_dict = {}
for k, v in crop_dict.items():
crop_size = (int(k), int(k))
crop_res_dict[k] = [_get_random_crop_indices(crop_region, crop_size) for _ in range(v)]
return crop_res_dict
def adaptive_sampler(pic_size, crop_dict, num_candidate_dict, dist_map, min_diff=False):
crop_region = (0, pic_size[0], 0, pic_size[1])
crop_res_dict = {}
for k, v in crop_dict.items():
crop_size = (int(k), int(k))
crop_res_dict[k] = [_get_adaptive_crop_indices(crop_region, crop_size, num_candidate_dict[k], dist_map, min_diff) for _ in range(v)]
return crop_res_dict
# TODO more flexible
def pyramid_sampler(pic_size, crop_dict):
crop_res_dict = {}
sorted_key = list(crop_dict.keys())
sorted_key.sort(key=lambda x: int(x), reverse=True)
k = sorted_key[0]
crop_size = (int(k), int(k))
crop_region = (0, pic_size[0], 0, pic_size[1])
crop_res_dict[k] = [_get_random_crop_indices(crop_region, crop_size) for _ in range(crop_dict[k])]
for i in range(1, len(sorted_key)):
crop_res_dict[sorted_key[i]] = []
afore_num = crop_dict[sorted_key[i-1]]
new_num = crop_dict[sorted_key[i]]
split_list = get_split_list(afore_num, new_num)
crop_size = (int(sorted_key[i]), int(sorted_key[i]))
for j in range(len(split_list)):
crop_region = crop_res_dict[sorted_key[i-1]][j]
crop_res_dict[sorted_key[i]].extend([_get_random_crop_indices(crop_region, crop_size) for _ in range(split_list[j])])
return crop_res_dict
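For orientation, the sampler contract in one sketch (run inside this module, since its import path is not shown in the diff): each key of crop_dict is a crop size as a string, each value a count, and every result tuple is (start_y, end_y, start_x, end_x).

crops = random_sampler((256, 256), {'64': 2, '32': 4})
for size, boxes in crops.items():
    for start_y, end_y, start_x, end_x in boxes:
        assert end_y - start_y == int(size) and end_x - start_x == int(size)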


@@ -0,0 +1,149 @@
'''
Copy from pytorch github repo
Spectral Normalization from https://arxiv.org/abs/1802.05957
'''
import torch
from torch.nn.functional import normalize
from torch.nn.parameter import Parameter
class SpectralNorm(object):
def __init__(self, name='weight', n_power_iterations=1, dim=0, eps=1e-12):
self.name = name
self.dim = dim
if n_power_iterations <= 0:
raise ValueError('Expected n_power_iterations to be positive, but '
'got n_power_iterations={}'.format(n_power_iterations))
self.n_power_iterations = n_power_iterations
self.eps = eps
def compute_weight(self, module):
weight = getattr(module, self.name + '_orig')
u = getattr(module, self.name + '_u')
weight_mat = weight
if self.dim != 0:
# permute dim to front
weight_mat = weight_mat.permute(self.dim,
*[d for d in range(weight_mat.dim()) if d != self.dim])
height = weight_mat.size(0)
weight_mat = weight_mat.reshape(height, -1)
with torch.no_grad():
for _ in range(self.n_power_iterations):
# Spectral norm of weight equals to `u^T W v`, where `u` and `v`
# are the first left and right singular vectors.
# This power iteration produces approximations of `u` and `v`.
v = normalize(torch.matmul(weight_mat.t(), u), dim=0, eps=self.eps)
u = normalize(torch.matmul(weight_mat, v), dim=0, eps=self.eps)
sigma = torch.dot(u, torch.matmul(weight_mat, v))
weight = weight / sigma
return weight, u
def remove(self, module):
weight = getattr(module, self.name)
delattr(module, self.name)
delattr(module, self.name + '_u')
delattr(module, self.name + '_orig')
module.register_parameter(self.name, torch.nn.Parameter(weight))
def __call__(self, module, inputs):
if module.training:
weight, u = self.compute_weight(module)
setattr(module, self.name, weight)
setattr(module, self.name + '_u', u)
else:
r_g = getattr(module, self.name + '_orig').requires_grad
getattr(module, self.name).detach_().requires_grad_(r_g)
@staticmethod
def apply(module, name, n_power_iterations, dim, eps):
fn = SpectralNorm(name, n_power_iterations, dim, eps)
weight = module._parameters[name]
height = weight.size(dim)
u = normalize(weight.new_empty(height).normal_(0, 1), dim=0, eps=fn.eps)
delattr(module, fn.name)
module.register_parameter(fn.name + "_orig", weight)
# We still need to assign weight back as fn.name because all sorts of
# things may assume that it exists, e.g., when initializing weights.
# However, we can't directly assign as it could be an nn.Parameter and
# gets added as a parameter. Instead, we register weight.data as a
# buffer, which will cause weight to be included in the state dict
# and also supports nn.init due to shared storage.
module.register_buffer(fn.name, weight.data)
module.register_buffer(fn.name + "_u", u)
module.register_forward_pre_hook(fn)
return fn
def spectral_norm(module, name='weight', n_power_iterations=1, eps=1e-12, dim=None):
r"""Applies spectral normalization to a parameter in the given module.
.. math::
\mathbf{W} &= \dfrac{\mathbf{W}}{\sigma(\mathbf{W})} \\
\sigma(\mathbf{W}) &= \max_{\mathbf{h}: \mathbf{h} \ne 0} \dfrac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2}
Spectral normalization stabilizes the training of discriminators (critics)
    in Generative Adversarial Networks (GANs) by rescaling the weight tensor
with spectral norm :math:`\sigma` of the weight matrix calculated using
power iteration method. If the dimension of the weight tensor is greater
than 2, it is reshaped to 2D in power iteration method to get spectral
norm. This is implemented via a hook that calculates spectral norm and
rescales weight before every :meth:`~Module.forward` call.
See `Spectral Normalization for Generative Adversarial Networks`_ .
.. _`Spectral Normalization for Generative Adversarial Networks`: https://arxiv.org/abs/1802.05957
Args:
module (nn.Module): containing module
name (str, optional): name of weight parameter
n_power_iterations (int, optional): number of power iterations to
        calculate spectral norm
eps (float, optional): epsilon for numerical stability in
calculating norms
dim (int, optional): dimension corresponding to number of outputs,
the default is 0, except for modules that are instances of
ConvTranspose1/2/3d, when it is 1
Returns:
        The original module with the spectral norm hook
Example::
>>> m = spectral_norm(nn.Linear(20, 40))
Linear (20 -> 40)
>>> m.weight_u.size()
torch.Size([20])
"""
if dim is None:
if isinstance(
module,
(torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d)):
dim = 1
else:
dim = 0
SpectralNorm.apply(module, name, n_power_iterations, dim, eps)
return module
def remove_spectral_norm(module, name='weight'):
r"""Removes the spectral normalization reparameterization from a module.
Args:
module (nn.Module): containing module
name (str, optional): name of weight parameter
Example:
>>> m = spectral_norm(nn.Linear(40, 10))
>>> remove_spectral_norm(m)
"""
for k, hook in module._forward_pre_hooks.items():
if isinstance(hook, SpectralNorm) and hook.name == name:
hook.remove(module)
del module._forward_pre_hooks[k]
return module
raise ValueError("spectral_norm of '{}' not found in {}".format(name, module))
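The docstring example applies to this vendored copy as well; a quick check, assuming the file is importable as models.SPSR_modules.spectral_norm (the `from . import spectral_norm as SN` alias in the architecture file above implies that location):

import torch.nn as nn
from models.SPSR_modules.spectral_norm import spectral_norm

m = spectral_norm(nn.Conv2d(3, 64, 3, 1, 1))  # same call Discriminator_VGG_128_SN makes
print(m.weight_u.size())  # torch.Size([64]): the left singular-vector estimate has one entry per output channel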


@@ -0,0 +1,161 @@
import functools
import logging
import torch
import torch.nn as nn
from torch.nn import init
import models.SPSR_modules.architecture as arch
logger = logging.getLogger('base')
####################
# initialize
####################
def weights_init_normal(m, std=0.02):
classname = m.__class__.__name__
if classname.find('Conv') != -1:
init.normal_(m.weight.data, 0.0, std)
if m.bias is not None:
m.bias.data.zero_()
elif classname.find('Linear') != -1:
init.normal_(m.weight.data, 0.0, std)
if m.bias is not None:
m.bias.data.zero_()
elif classname.find('BatchNorm2d') != -1:
init.normal_(m.weight.data, 1.0, std) # BN also uses norm
init.constant_(m.bias.data, 0.0)
def weights_init_kaiming(m, scale=1):
classname = m.__class__.__name__
if classname.find('Conv') != -1:
init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
m.weight.data *= scale
if m.bias is not None:
m.bias.data.zero_()
elif classname.find('Linear') != -1:
init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
m.weight.data *= scale
if m.bias is not None:
m.bias.data.zero_()
elif classname.find('BatchNorm2d') != -1:
        if m.affine:
init.constant_(m.weight.data, 1.0)
init.constant_(m.bias.data, 0.0)
def weights_init_orthogonal(m):
classname = m.__class__.__name__
if classname.find('Conv') != -1:
init.orthogonal_(m.weight.data, gain=1)
if m.bias is not None:
m.bias.data.zero_()
elif classname.find('Linear') != -1:
init.orthogonal_(m.weight.data, gain=1)
if m.bias is not None:
m.bias.data.zero_()
elif classname.find('BatchNorm2d') != -1:
init.constant_(m.weight.data, 1.0)
init.constant_(m.bias.data, 0.0)
def init_weights(net, init_type='kaiming', scale=1, std=0.02):
# scale for 'kaiming', std for 'normal'.
if init_type == 'normal':
weights_init_normal_ = functools.partial(weights_init_normal, std=std)
net.apply(weights_init_normal_)
elif init_type == 'kaiming':
weights_init_kaiming_ = functools.partial(weights_init_kaiming, scale=scale)
net.apply(weights_init_kaiming_)
elif init_type == 'orthogonal':
net.apply(weights_init_orthogonal)
else:
raise NotImplementedError('initialization method [{:s}] not implemented'.format(init_type))
####################
# define network
####################
# Generator
def define_G(opt, device=None):
opt_net = opt['network_G']
which_model = opt_net['which_model_G']
if which_model == 'spsr_net':
netG = arch.SPSRNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'],
nb=opt_net['nb'], gc=opt_net['gc'], upscale=opt_net['scale'], norm_type=opt_net['norm_type'],
act_type='leakyrelu', mode=opt_net['mode'], upsample_mode='upconv')
else:
raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model))
if opt['is_train']:
init_weights(netG, init_type='kaiming', scale=0.1)
return netG
# Discriminator
def define_D(opt):
opt_net = opt['network_D']
which_model = opt_net['which_model_D']
if which_model == 'discriminator_vgg_128':
netD = arch.Discriminator_VGG_128(in_nc=opt_net['in_nc'], base_nf=opt_net['nf'], \
norm_type=opt_net['norm_type'], mode=opt_net['mode'], act_type=opt_net['act_type'])
elif which_model == 'discriminator_vgg_96':
netD = arch.Discriminator_VGG_96(in_nc=opt_net['in_nc'], base_nf=opt_net['nf'], \
norm_type=opt_net['norm_type'], mode=opt_net['mode'], act_type=opt_net['act_type'])
elif which_model == 'discriminator_vgg_192':
netD = arch.Discriminator_VGG_192(in_nc=opt_net['in_nc'], base_nf=opt_net['nf'], \
norm_type=opt_net['norm_type'], mode=opt_net['mode'], act_type=opt_net['act_type'])
elif which_model == 'discriminator_vgg_128_SN':
netD = arch.Discriminator_VGG_128_SN()
else:
raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model))
init_weights(netD, init_type='kaiming', scale=1)
return netD
def define_D_grad(opt):
opt_net = opt['network_D']
which_model = opt_net['which_model_D']
if which_model == 'discriminator_vgg_128':
netD = arch.Discriminator_VGG_128(in_nc=opt_net['in_nc'], base_nf=opt_net['nf'], \
norm_type=opt_net['norm_type'], mode=opt_net['mode'], act_type=opt_net['act_type'])
elif which_model == 'discriminator_vgg_96':
netD = arch.Discriminator_VGG_96(in_nc=opt_net['in_nc'], base_nf=opt_net['nf'], \
norm_type=opt_net['norm_type'], mode=opt_net['mode'], act_type=opt_net['act_type'])
elif which_model == 'discriminator_vgg_192':
netD = arch.Discriminator_VGG_192(in_nc=opt_net['in_nc'], base_nf=opt_net['nf'], \
norm_type=opt_net['norm_type'], mode=opt_net['mode'], act_type=opt_net['act_type'])
elif which_model == 'discriminator_vgg_128_SN':
netD = arch.Discriminator_VGG_128_SN()
else:
raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model))
init_weights(netD, init_type='kaiming', scale=1)
return netD
def define_F(opt, use_bn=False):
device = torch.device('cuda')
# pytorch pretrained VGG19-54, before ReLU.
if use_bn:
feature_layer = 49
else:
feature_layer = 34
netF = arch.VGGFeatureExtractor(feature_layer=feature_layer, use_bn=use_bn, \
use_input_norm=True, device=device)
netF.eval()
return netF
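A minimal sketch wiring the factories together (key names follow the accessors above; the concrete values are illustrative, not taken from this commit's option files):

import models.SPSR_networks as networks

opt = {'is_train': False,
       'network_G': {'which_model_G': 'spsr_net', 'in_nc': 3, 'out_nc': 3,
                     'nf': 64, 'nb': 23, 'gc': 32, 'scale': 4,
                     'norm_type': None, 'mode': 'CNA'},
       'network_D': {'which_model_D': 'discriminator_vgg_128', 'in_nc': 3, 'nf': 64,
                     'norm_type': 'batch', 'act_type': 'leakyrelu', 'mode': 'CNA'}}
netG = networks.define_G(opt)
netD = networks.define_D(opt)  # define_D_grad builds the identical architecture for gradient maps
netF = networks.define_F(opt)  # note: define_F hard-codes torch.device('cuda')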


@@ -11,6 +11,8 @@ def create_model(opt):
         from .SRGAN_model import SRGANModel as M
     elif model == 'feat':
         from .feature_model import FeatureModel as M
+    elif model == 'spsr':
+        from .SPSR_model import SPSRModel as M
     else:
         raise NotImplementedError('Model [{:s}] not recognized.'.format(model))
     m = M(opt)


@@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs):
 def main():
     #### options
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_pixgan_srg4_lr_feat.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_spsr.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
                         help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)