forked from mrq/DL-Art-School
Add SPSR_module
This is a port from the SPSR repo. It's going to need a lot of work to be properly integrated, but as of this commit it at least runs.
This commit is contained in:
parent
f33ed578a2
commit
f894ba8f98
462
codes/models/SPSR_model.py
Normal file
@@ -0,0 +1,462 @@
import os
import logging
from collections import OrderedDict

import torch
import torch.nn as nn
from torch.optim import lr_scheduler

import models.SPSR_networks as networks
from .base_model import BaseModel
from models.SPSR_modules.loss import GANLoss, GradientPenaltyLoss
logger = logging.getLogger('base')

import torch.nn.functional as F


class Get_gradient(nn.Module):
    def __init__(self):
        super(Get_gradient, self).__init__()
        kernel_v = [[0, -1, 0],
                    [0, 0, 0],
                    [0, 1, 0]]
        kernel_h = [[0, 0, 0],
                    [-1, 0, 1],
                    [0, 0, 0]]
        kernel_h = torch.FloatTensor(kernel_h).unsqueeze(0).unsqueeze(0)
        kernel_v = torch.FloatTensor(kernel_v).unsqueeze(0).unsqueeze(0)
        self.weight_h = nn.Parameter(data = kernel_h, requires_grad = False).cuda()
        self.weight_v = nn.Parameter(data = kernel_v, requires_grad = False).cuda()

    def forward(self, x):
        x0 = x[:, 0]
        x1 = x[:, 1]
        x2 = x[:, 2]
        x0_v = F.conv2d(x0.unsqueeze(1), self.weight_v, padding=2)
        x0_h = F.conv2d(x0.unsqueeze(1), self.weight_h, padding=2)

        x1_v = F.conv2d(x1.unsqueeze(1), self.weight_v, padding=2)
        x1_h = F.conv2d(x1.unsqueeze(1), self.weight_h, padding=2)

        x2_v = F.conv2d(x2.unsqueeze(1), self.weight_v, padding=2)
        x2_h = F.conv2d(x2.unsqueeze(1), self.weight_h, padding=2)

        x0 = torch.sqrt(torch.pow(x0_v, 2) + torch.pow(x0_h, 2) + 1e-6)
        x1 = torch.sqrt(torch.pow(x1_v, 2) + torch.pow(x1_h, 2) + 1e-6)
        x2 = torch.sqrt(torch.pow(x2_v, 2) + torch.pow(x2_h, 2) + 1e-6)

        x = torch.cat([x0, x1, x2], dim=1)
        return x


class Get_gradient_nopadding(nn.Module):
    def __init__(self):
        super(Get_gradient_nopadding, self).__init__()
        kernel_v = [[0, -1, 0],
                    [0, 0, 0],
                    [0, 1, 0]]
        kernel_h = [[0, 0, 0],
                    [-1, 0, 1],
                    [0, 0, 0]]
        kernel_h = torch.FloatTensor(kernel_h).unsqueeze(0).unsqueeze(0)
        kernel_v = torch.FloatTensor(kernel_v).unsqueeze(0).unsqueeze(0)
        self.weight_h = nn.Parameter(data = kernel_h, requires_grad = False).cuda()
        self.weight_v = nn.Parameter(data = kernel_v, requires_grad = False).cuda()

    def forward(self, x):
        x0 = x[:, 0]
        x1 = x[:, 1]
        x2 = x[:, 2]
        x0_v = F.conv2d(x0.unsqueeze(1), self.weight_v, padding = 1)
        x0_h = F.conv2d(x0.unsqueeze(1), self.weight_h, padding = 1)

        x1_v = F.conv2d(x1.unsqueeze(1), self.weight_v, padding = 1)
        x1_h = F.conv2d(x1.unsqueeze(1), self.weight_h, padding = 1)

        x2_v = F.conv2d(x2.unsqueeze(1), self.weight_v, padding = 1)
        x2_h = F.conv2d(x2.unsqueeze(1), self.weight_h, padding = 1)

        x0 = torch.sqrt(torch.pow(x0_v, 2) + torch.pow(x0_h, 2) + 1e-6)
        x1 = torch.sqrt(torch.pow(x1_v, 2) + torch.pow(x1_h, 2) + 1e-6)
        x2 = torch.sqrt(torch.pow(x2_v, 2) + torch.pow(x2_h, 2) + 1e-6)

        x = torch.cat([x0, x1, x2], dim=1)
        return x


class SPSRModel(BaseModel):
    def __init__(self, opt):
        super(SPSRModel, self).__init__(opt)
        train_opt = opt['train']

        # define networks and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)  # G
        if self.is_train:
            self.netD = networks.define_D(opt).to(self.device)  # D
            self.netD_grad = networks.define_D_grad(opt).to(self.device)  # D_grad
            self.netG.train()
            self.netD.train()
            self.netD_grad.train()
        self.load()  # load G and D if needed

        # define losses, optimizer and scheduler
        if self.is_train:
            # G pixel loss
            if train_opt['pixel_weight'] > 0:
                l_pix_type = train_opt['pixel_criterion']
                if l_pix_type == 'l1':
                    self.cri_pix = nn.L1Loss().to(self.device)
                elif l_pix_type == 'l2':
                    self.cri_pix = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_pix_type))
                self.l_pix_w = train_opt['pixel_weight']
            else:
                logger.info('Remove pixel loss.')
                self.cri_pix = None

            # G feature loss
            if train_opt['feature_weight'] > 0:
                l_fea_type = train_opt['feature_criterion']
                if l_fea_type == 'l1':
                    self.cri_fea = nn.L1Loss().to(self.device)
                elif l_fea_type == 'l2':
                    self.cri_fea = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_fea_type))
                self.l_fea_w = train_opt['feature_weight']
            else:
                logger.info('Remove feature loss.')
                self.cri_fea = None
            if self.cri_fea:  # load VGG perceptual loss
                self.netF = networks.define_F(opt, use_bn=False).to(self.device)

            # GD gan loss
            self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device)
            self.l_gan_w = train_opt['gan_weight']
            # D_update_ratio and D_init_iters are for WGAN
            self.D_update_ratio = train_opt['D_update_ratio'] if train_opt['D_update_ratio'] else 1
            self.D_init_iters = train_opt['D_init_iters'] if train_opt['D_init_iters'] else 0
            # Branch_init_iters
            self.Branch_pretrain = train_opt['Branch_pretrain'] if train_opt['Branch_pretrain'] else 0
            self.Branch_init_iters = train_opt['Branch_init_iters'] if train_opt['Branch_init_iters'] else 1

            if train_opt['gan_type'] == 'wgan-gp':
                self.random_pt = torch.Tensor(1, 1, 1, 1).to(self.device)
                # gradient penalty loss
                self.cri_gp = GradientPenaltyLoss(device=self.device).to(self.device)
                self.l_gp_w = train_opt['gp_weigth']

            # gradient_pixel_loss
            if train_opt['gradient_pixel_weight'] > 0:
                self.cri_pix_grad = nn.MSELoss().to(self.device)
                self.l_pix_grad_w = train_opt['gradient_pixel_weight']
            else:
                self.cri_pix_grad = None

            # gradient_gan_loss
            if train_opt['gradient_gan_weight'] > 0:
                self.cri_grad_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device)
                self.l_gan_grad_w = train_opt['gradient_gan_weight']
            else:
                self.cri_grad_gan = None

            # G_grad pixel loss
            if train_opt['pixel_branch_weight'] > 0:
                l_pix_type = train_opt['pixel_branch_criterion']
                if l_pix_type == 'l1':
                    self.cri_pix_branch = nn.L1Loss().to(self.device)
                elif l_pix_type == 'l2':
                    self.cri_pix_branch = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_pix_type))
                self.l_pix_branch_w = train_opt['pixel_branch_weight']
            else:
                logger.info('Remove G_grad pixel loss.')
                self.cri_pix_branch = None

            # optimizers
            # G
            wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0

            optim_params = []
            for k, v in self.netG.named_parameters():  # optimize part of the model
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    logger.warning('Params [{:s}] will not optimize.'.format(k))
            self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'], \
                weight_decay=wd_G, betas=(train_opt['beta1_G'], 0.999))
            self.optimizers.append(self.optimizer_G)

            # D
            wd_D = train_opt['weight_decay_D'] if train_opt['weight_decay_D'] else 0
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=train_opt['lr_D'], \
                weight_decay=wd_D, betas=(train_opt['beta1_D'], 0.999))
            self.optimizers.append(self.optimizer_D)

            # D_grad
            wd_D_grad = train_opt['weight_decay_D'] if train_opt['weight_decay_D'] else 0
            self.optimizer_D_grad = torch.optim.Adam(self.netD_grad.parameters(), lr=train_opt['lr_D'], \
                weight_decay=wd_D, betas=(train_opt['beta1_D'], 0.999))
            self.optimizers.append(self.optimizer_D_grad)

            # schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(lr_scheduler.MultiStepLR(optimizer, \
                        train_opt['lr_steps'], train_opt['lr_gamma']))
            else:
                raise NotImplementedError('MultiStepLR learning rate scheme is enough.')

            self.log_dict = OrderedDict()
            self.get_grad = Get_gradient()
            self.get_grad_nopadding = Get_gradient_nopadding()

    def feed_data(self, data, need_HR=True):
        # LR
        self.var_L = data['LQ'].to(self.device)

        if need_HR:  # train or val
            self.var_H = data['GT'].to(self.device)
            input_ref = data['ref'] if 'ref' in data else data['GT']
            self.var_ref = input_ref.to(self.device)

    def optimize_parameters(self, step):
        # G
        for p in self.netD.parameters():
            p.requires_grad = False

        for p in self.netD_grad.parameters():
            p.requires_grad = False

        if(self.Branch_pretrain):
            if(step < self.Branch_init_iters):
                for k, v in self.netG.named_parameters():
                    if 'f_' not in k:
                        v.requires_grad = False
            else:
                for k, v in self.netG.named_parameters():
                    if 'f_' not in k:
                        v.requires_grad = True

        self.optimizer_G.zero_grad()

        self.fake_H_branch, self.fake_H, self.grad_LR = self.netG(self.var_L)

        self.fake_H_grad = self.get_grad(self.fake_H)
        self.var_H_grad = self.get_grad(self.var_H)
        self.var_ref_grad = self.get_grad(self.var_ref)
        self.var_H_grad_nopadding = self.get_grad_nopadding(self.var_H)

        l_g_total = 0
        if step % self.D_update_ratio == 0 and step > self.D_init_iters:
            if self.cri_pix:  # pixel loss
                l_g_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.var_H)
                l_g_total += l_g_pix
            if self.cri_fea:  # feature loss
                real_fea = self.netF(self.var_H).detach()
                fake_fea = self.netF(self.fake_H)
                l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea)
                l_g_total += l_g_fea

            if self.cri_pix_grad:  # gradient pixel loss
                l_g_pix_grad = self.l_pix_grad_w * self.cri_pix_grad(self.fake_H_grad, self.var_H_grad)
                l_g_total += l_g_pix_grad

            if self.cri_pix_branch:  # branch pixel loss
                l_g_pix_grad_branch = self.l_pix_branch_w * self.cri_pix_branch(self.fake_H_branch, self.var_H_grad_nopadding)
                l_g_total += l_g_pix_grad_branch

            # G gan + cls loss
            pred_g_fake = self.netD(self.fake_H)
            pred_d_real = self.netD(self.var_ref).detach()

            l_g_gan = self.l_gan_w * (self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False) +
                                      self.cri_gan(pred_g_fake - torch.mean(pred_d_real), True)) / 2
            l_g_total += l_g_gan

            # grad G gan + cls loss
            pred_g_fake_grad = self.netD_grad(self.fake_H_grad)
            pred_d_real_grad = self.netD_grad(self.var_ref_grad).detach()

            l_g_gan_grad = self.l_gan_grad_w * (self.cri_grad_gan(pred_d_real_grad - torch.mean(pred_g_fake_grad), False) +
                                                self.cri_grad_gan(pred_g_fake_grad - torch.mean(pred_d_real_grad), True)) / 2
            l_g_total += l_g_gan_grad

            l_g_total.backward()
            self.optimizer_G.step()

        # D
        for p in self.netD.parameters():
            p.requires_grad = True

        self.optimizer_D.zero_grad()
        l_d_total = 0
        pred_d_real = self.netD(self.var_ref)
        pred_d_fake = self.netD(self.fake_H.detach())  # detach to avoid BP to G
        l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True)
        l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real), False)

        l_d_total = (l_d_real + l_d_fake) / 2

        if self.opt['train']['gan_type'] == 'wgan-gp':
            batch_size = self.var_ref.size(0)
            if self.random_pt.size(0) != batch_size:
                self.random_pt.resize_(batch_size, 1, 1, 1)
            self.random_pt.uniform_()  # Draw random interpolation points
            interp = self.random_pt * self.fake_H.detach() + (1 - self.random_pt) * self.var_ref
            interp.requires_grad = True
            interp_crit, _ = self.netD(interp)
            l_d_gp = self.l_gp_w * self.cri_gp(interp, interp_crit)
            l_d_total += l_d_gp

        l_d_total.backward()
        self.optimizer_D.step()

        for p in self.netD_grad.parameters():
            p.requires_grad = True

        self.optimizer_D_grad.zero_grad()
        l_d_total_grad = 0

        pred_d_real_grad = self.netD_grad(self.var_ref_grad)
        pred_d_fake_grad = self.netD_grad(self.fake_H_grad.detach())  # detach to avoid BP to G

        l_d_real_grad = self.cri_grad_gan(pred_d_real_grad - torch.mean(pred_d_fake_grad), True)
        l_d_fake_grad = self.cri_grad_gan(pred_d_fake_grad - torch.mean(pred_d_real_grad), False)

        l_d_total_grad = (l_d_real_grad + l_d_fake_grad) / 2

        l_d_total_grad.backward()
        self.optimizer_D_grad.step()

        # Log sample images from first microbatch.
        if step % 50 == 0:
            import torchvision.utils as utils
            sample_save_path = os.path.join(self.opt['path']['models'], "..", "temp")
            os.makedirs(os.path.join(sample_save_path, "hr"), exist_ok=True)
            os.makedirs(os.path.join(sample_save_path, "lr"), exist_ok=True)
            os.makedirs(os.path.join(sample_save_path, "gen"), exist_ok=True)
            # fed_LQ is not chunked.
            utils.save_image(self.var_H.cpu(), os.path.join(sample_save_path, "hr", "%05i.png" % (step,)))
            utils.save_image(self.var_L.cpu(), os.path.join(sample_save_path, "lr", "%05i.png" % (step,)))
            utils.save_image(self.fake_H.cpu(), os.path.join(sample_save_path, "gen", "%05i.png" % (step,)))

        # set log
        if step % self.D_update_ratio == 0 and step > self.D_init_iters:
            # G
            if self.cri_pix:
                self.log_dict['l_g_pix'] = l_g_pix.item()
            if self.cri_fea:
                self.log_dict['l_g_fea'] = l_g_fea.item()
            self.log_dict['l_g_gan'] = l_g_gan.item()

            if self.cri_pix_branch:  # branch pixel loss
                self.log_dict['l_g_pix_grad_branch'] = l_g_pix_grad_branch.item()

        # D
        self.log_dict['l_d_real'] = l_d_real.item()
        self.log_dict['l_d_fake'] = l_d_fake.item()

        # D_grad
        self.log_dict['l_d_real_grad'] = l_d_real_grad.item()
        self.log_dict['l_d_fake_grad'] = l_d_fake_grad.item()

        if self.opt['train']['gan_type'] == 'wgan-gp':
            self.log_dict['l_d_gp'] = l_d_gp.item()
        # D outputs
        self.log_dict['D_real'] = torch.mean(pred_d_real.detach())
        self.log_dict['D_fake'] = torch.mean(pred_d_fake.detach())

        # D_grad outputs
        self.log_dict['D_real_grad'] = torch.mean(pred_d_real_grad.detach())
        self.log_dict['D_fake_grad'] = torch.mean(pred_d_fake_grad.detach())

    def test(self):
        self.netG.eval()
        with torch.no_grad():
            self.fake_H_branch, self.fake_H, self.grad_LR = self.netG(self.var_L)

        self.netG.train()

    def get_current_log(self, step):
        return self.log_dict

    def get_current_visuals(self, need_HR=True):
        out_dict = OrderedDict()
        out_dict['LR'] = self.var_L.detach()[0].float().cpu()

        out_dict['SR'] = self.fake_H.detach()[0].float().cpu()
        out_dict['SR_branch'] = self.fake_H_branch.detach()[0].float().cpu()
        out_dict['LR_grad'] = self.grad_LR.detach()[0].float().cpu()
        if need_HR:
            out_dict['HR'] = self.var_H.detach()[0].float().cpu()
        return out_dict

    def print_network(self):
        # Generator
        s, n = self.get_network_description(self.netG)
        if isinstance(self.netG, nn.DataParallel):
            net_struc_str = '{} - {}'.format(self.netG.__class__.__name__,
                                             self.netG.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netG.__class__.__name__)

        logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
        logger.info(s)
        if self.is_train:
            # Discriminator
            s, n = self.get_network_description(self.netD)
            if isinstance(self.netD, nn.DataParallel):
                net_struc_str = '{} - {}'.format(self.netD.__class__.__name__,
                                                 self.netD.module.__class__.__name__)
            else:
                net_struc_str = '{}'.format(self.netD.__class__.__name__)

            logger.info('Network D structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
            logger.info(s)

            if self.cri_fea:  # F, Perceptual Network
                s, n = self.get_network_description(self.netF)
                if isinstance(self.netF, nn.DataParallel):
                    net_struc_str = '{} - {}'.format(self.netF.__class__.__name__,
                                                     self.netF.module.__class__.__name__)
                else:
                    net_struc_str = '{}'.format(self.netF.__class__.__name__)

                logger.info('Network F structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
                logger.info(s)

    def load(self):
        load_path_G = self.opt['path']['pretrain_model_G']
        if load_path_G is not None:
            logger.info('Loading pretrained model for G [{:s}] ...'.format(load_path_G))
            self.load_network(load_path_G, self.netG)
        load_path_D = self.opt['path']['pretrain_model_D']
        if self.opt['is_train'] and load_path_D is not None:
            logger.info('Loading pretrained model for D [{:s}] ...'.format(load_path_D))
            self.load_network(load_path_D, self.netD)

    def save(self, iter_step):
        self.save_network(self.netG, 'G', iter_step)
        self.save_network(self.netD, 'D', iter_step)
        self.save_network(self.netD_grad, 'D_grad', iter_step)
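Reviewer note: the generator and discriminator updates in optimize_parameters above follow the relativistic average GAN pattern (each prediction is compared against the mean prediction of the opposite class), applied twice: once to the image discriminator netD and once to the gradient-map discriminator netD_grad. Below is a minimal, self-contained sketch of that loss pattern. It uses torch.nn.BCEWithLogitsLoss as a simplified stand-in for the GANLoss class imported from models.SPSR_modules.loss, and the helper names ragan_g_loss / ragan_d_loss are hypothetical, not part of the port.

import torch
import torch.nn as nn

# Stand-in for GANLoss(gan_type, 1.0, 0.0); assumes the discriminator outputs raw logits.
bce = nn.BCEWithLogitsLoss()

def ragan_g_loss(pred_real, pred_fake):
    # Generator objective: real should score below the average fake, fake above the average real.
    loss_real = bce(pred_real - pred_fake.mean(), torch.zeros_like(pred_real))
    loss_fake = bce(pred_fake - pred_real.mean(), torch.ones_like(pred_fake))
    return (loss_real + loss_fake) / 2

def ragan_d_loss(pred_real, pred_fake):
    # Discriminator objective: the mirror image of the generator objective.
    loss_real = bce(pred_real - pred_fake.mean(), torch.ones_like(pred_real))
    loss_fake = bce(pred_fake - pred_real.mean(), torch.zeros_like(pred_fake))
    return (loss_real + loss_fake) / 2

if __name__ == '__main__':
    torch.manual_seed(0)
    r, f = torch.randn(4, 1), torch.randn(4, 1)
    print(float(ragan_g_loss(r, f)), float(ragan_d_loss(r, f)))

In the model code above, the same two-term average is weighted by l_gan_w (images) and l_gan_grad_w (gradient maps), and the fake predictions are detached on the discriminator step to avoid backpropagating into G.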
0
codes/models/SPSR_modules/__init__.py
Normal file
654
codes/models/SPSR_modules/architecture.py
Normal file
@@ -0,0 +1,654 @@
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from . import block as B
from . import spectral_norm as SN


class Get_gradient_nopadding(nn.Module):
    def __init__(self):
        super(Get_gradient_nopadding, self).__init__()
        kernel_v = [[0, -1, 0],
                    [0, 0, 0],
                    [0, 1, 0]]
        kernel_h = [[0, 0, 0],
                    [-1, 0, 1],
                    [0, 0, 0]]
        kernel_h = torch.FloatTensor(kernel_h).unsqueeze(0).unsqueeze(0)
        kernel_v = torch.FloatTensor(kernel_v).unsqueeze(0).unsqueeze(0)
        self.weight_h = nn.Parameter(data = kernel_h, requires_grad = False)
        self.weight_v = nn.Parameter(data = kernel_v, requires_grad = False)

    def forward(self, x):
        x_list = []
        for i in range(x.shape[1]):
            x_i = x[:, i]
            x_i_v = F.conv2d(x_i.unsqueeze(1), self.weight_v, padding=1)
            x_i_h = F.conv2d(x_i.unsqueeze(1), self.weight_h, padding=1)
            x_i = torch.sqrt(torch.pow(x_i_v, 2) + torch.pow(x_i_h, 2) + 1e-6)
            x_list.append(x_i)

        x = torch.cat(x_list, dim = 1)

        return x


####################
# Generator
####################

class SPSRNet(nn.Module):
    def __init__(self, in_nc, out_nc, nf, nb, gc=32, upscale=4, norm_type=None, \
            act_type='leakyrelu', mode='CNA', upsample_mode='upconv'):
        super(SPSRNet, self).__init__()

        n_upscale = int(math.log(upscale, 2))

        if upscale == 3:
            n_upscale = 1

        fea_conv = B.conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None)
        rb_blocks = [B.RRDB(nf, kernel_size=3, gc=32, stride=1, bias=True, pad_type='zero', \
            norm_type=norm_type, act_type=act_type, mode='CNA') for _ in range(nb)]

        LR_conv = B.conv_block(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None, mode=mode)

        if upsample_mode == 'upconv':
            upsample_block = B.upconv_blcok
        elif upsample_mode == 'pixelshuffle':
            upsample_block = B.pixelshuffle_block
        else:
            raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode))
        if upscale == 3:
            upsampler = upsample_block(nf, nf, 3, act_type=act_type)
        else:
            upsampler = [upsample_block(nf, nf, act_type=act_type) for _ in range(n_upscale)]

        self.HR_conv0_new = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type)
        self.HR_conv1_new = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=None)

        self.model = B.sequential(fea_conv, B.ShortcutBlock(B.sequential(*rb_blocks, LR_conv)),\
            *upsampler, self.HR_conv0_new)

        self.get_g_nopadding = Get_gradient_nopadding()

        self.b_fea_conv = B.conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None)

        self.b_concat_1 = B.conv_block(2*nf, nf, kernel_size=3, norm_type=None, act_type = None)
        self.b_block_1 = B.RRDB(nf*2, kernel_size=3, gc=32, stride=1, bias=True, pad_type='zero', \
            norm_type=norm_type, act_type=act_type, mode='CNA')

        self.b_concat_2 = B.conv_block(2*nf, nf, kernel_size=3, norm_type=None, act_type = None)
        self.b_block_2 = B.RRDB(nf*2, kernel_size=3, gc=32, stride=1, bias=True, pad_type='zero', \
            norm_type=norm_type, act_type=act_type, mode='CNA')

        self.b_concat_3 = B.conv_block(2*nf, nf, kernel_size=3, norm_type=None, act_type = None)
        self.b_block_3 = B.RRDB(nf*2, kernel_size=3, gc=32, stride=1, bias=True, pad_type='zero', \
            norm_type=norm_type, act_type=act_type, mode='CNA')

        self.b_concat_4 = B.conv_block(2*nf, nf, kernel_size=3, norm_type=None, act_type = None)
        self.b_block_4 = B.RRDB(nf*2, kernel_size=3, gc=32, stride=1, bias=True, pad_type='zero', \
            norm_type=norm_type, act_type=act_type, mode='CNA')

        self.b_LR_conv = B.conv_block(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None, mode=mode)

        if upsample_mode == 'upconv':
            upsample_block = B.upconv_blcok
        elif upsample_mode == 'pixelshuffle':
            upsample_block = B.pixelshuffle_block
        else:
            raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode))
        if upscale == 3:
            b_upsampler = upsample_block(nf, nf, 3, act_type=act_type)
        else:
            b_upsampler = [upsample_block(nf, nf, act_type=act_type) for _ in range(n_upscale)]

        b_HR_conv0 = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type)
        b_HR_conv1 = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=None)

        self.b_module = B.sequential(*b_upsampler, b_HR_conv0, b_HR_conv1)

        self.conv_w = B.conv_block(nf, out_nc, kernel_size=1, norm_type=None, act_type=None)

        self.f_concat = B.conv_block(nf*2, nf, kernel_size=3, norm_type=None, act_type=None)

        self.f_block = B.RRDB(nf*2, kernel_size=3, gc=32, stride=1, bias=True, pad_type='zero', \
            norm_type=norm_type, act_type=act_type, mode='CNA')

        self.f_HR_conv0 = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type)
        self.f_HR_conv1 = B.conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None)

    def forward(self, x):

        x_grad = self.get_g_nopadding(x)
        x = self.model[0](x)

        x, block_list = self.model[1](x)

        x_ori = x
        for i in range(5):
            x = block_list[i](x)
        x_fea1 = x

        for i in range(5):
            x = block_list[i+5](x)
        x_fea2 = x

        for i in range(5):
            x = block_list[i+10](x)
        x_fea3 = x

        for i in range(5):
            x = block_list[i+15](x)
        x_fea4 = x

        x = block_list[20:](x)
        #short cut
        x = x_ori + x
        x = self.model[2:](x)
        x = self.HR_conv1_new(x)

        x_b_fea = self.b_fea_conv(x_grad)
        x_cat_1 = torch.cat([x_b_fea, x_fea1], dim=1)

        x_cat_1 = self.b_block_1(x_cat_1)
        x_cat_1 = self.b_concat_1(x_cat_1)

        x_cat_2 = torch.cat([x_cat_1, x_fea2], dim=1)

        x_cat_2 = self.b_block_2(x_cat_2)
        x_cat_2 = self.b_concat_2(x_cat_2)

        x_cat_3 = torch.cat([x_cat_2, x_fea3], dim=1)

        x_cat_3 = self.b_block_3(x_cat_3)
        x_cat_3 = self.b_concat_3(x_cat_3)

        x_cat_4 = torch.cat([x_cat_3, x_fea4], dim=1)

        x_cat_4 = self.b_block_4(x_cat_4)
        x_cat_4 = self.b_concat_4(x_cat_4)

        x_cat_4 = self.b_LR_conv(x_cat_4)

        #short cut
        x_cat_4 = x_cat_4 + x_b_fea
        x_branch = self.b_module(x_cat_4)

        x_out_branch = self.conv_w(x_branch)
        ########
        x_branch_d = x_branch
        x_f_cat = torch.cat([x_branch_d, x], dim=1)
        x_f_cat = self.f_block(x_f_cat)
        x_out = self.f_concat(x_f_cat)
        x_out = self.f_HR_conv0(x_out)
        x_out = self.f_HR_conv1(x_out)

        #########
        return x_out_branch, x_out, x_grad


####################
# Discriminator
####################


# VGG style Discriminator with input size 128*128
class Discriminator_VGG_128(nn.Module):
    def __init__(self, in_nc, base_nf, norm_type='batch', act_type='leakyrelu', mode='CNA'):
        super(Discriminator_VGG_128, self).__init__()
        # features
        # hxw, c
        # 128, 64
        conv0 = B.conv_block(in_nc, base_nf, kernel_size=3, norm_type=None, act_type=act_type, \
            mode=mode)
        conv1 = B.conv_block(base_nf, base_nf, kernel_size=4, stride=2, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        # 64, 64
        conv2 = B.conv_block(base_nf, base_nf*2, kernel_size=3, stride=1, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        conv3 = B.conv_block(base_nf*2, base_nf*2, kernel_size=4, stride=2, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        # 32, 128
        conv4 = B.conv_block(base_nf*2, base_nf*4, kernel_size=3, stride=1, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        conv5 = B.conv_block(base_nf*4, base_nf*4, kernel_size=4, stride=2, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        # 16, 256
        conv6 = B.conv_block(base_nf*4, base_nf*8, kernel_size=3, stride=1, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        conv7 = B.conv_block(base_nf*8, base_nf*8, kernel_size=4, stride=2, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        # 8, 512
        conv8 = B.conv_block(base_nf*8, base_nf*8, kernel_size=3, stride=1, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        conv9 = B.conv_block(base_nf*8, base_nf*8, kernel_size=4, stride=2, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        # 4, 512
        self.features = B.sequential(conv0, conv1, conv2, conv3, conv4, conv5, conv6, conv7, conv8,\
            conv9)

        # classifier
        self.classifier = nn.Sequential(
            nn.Linear(512 * 4 * 4, 100), nn.LeakyReLU(0.2, True), nn.Linear(100, 1))

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x


# VGG style Discriminator with input size 96*96
class Discriminator_VGG_96(nn.Module):
    def __init__(self, in_nc, base_nf, norm_type='batch', act_type='leakyrelu', mode='CNA'):
        super(Discriminator_VGG_96, self).__init__()
        # features
        # hxw, c
        # 96, 3
        conv0 = B.conv_block(in_nc, base_nf, kernel_size=3, norm_type=None, act_type=act_type, \
            mode=mode)
        conv1 = B.conv_block(base_nf, base_nf, kernel_size=4, stride=2, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        # 48, 64
        conv2 = B.conv_block(base_nf, base_nf*2, kernel_size=3, stride=1, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        conv3 = B.conv_block(base_nf*2, base_nf*2, kernel_size=4, stride=2, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        # 24, 128
        conv4 = B.conv_block(base_nf*2, base_nf*4, kernel_size=3, stride=1, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        conv5 = B.conv_block(base_nf*4, base_nf*4, kernel_size=4, stride=2, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        # 12, 256
        conv6 = B.conv_block(base_nf*4, base_nf*8, kernel_size=3, stride=1, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        conv7 = B.conv_block(base_nf*8, base_nf*8, kernel_size=4, stride=2, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        # 6, 512
        self.features = B.sequential(conv0, conv1, conv2, conv3, conv4, conv5, conv6, conv7)

        # classifier
        self.classifier = nn.Sequential(
            nn.Linear(512 * 6 * 6, 100), nn.LeakyReLU(0.2, True), nn.Linear(100, 1))

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x


# VGG style Discriminator with input size 64*64
class Discriminator_VGG_64(nn.Module):
    def __init__(self, in_nc, base_nf, norm_type='batch', act_type='leakyrelu', mode='CNA'):
        super(Discriminator_VGG_64, self).__init__()
        # features
        # hxw, c
        # 64, 3
        conv0 = B.conv_block(in_nc, base_nf, kernel_size=3, norm_type=None, act_type=act_type, \
            mode=mode)
        conv1 = B.conv_block(base_nf, base_nf, kernel_size=4, stride=2, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        # 32, 64
        conv2 = B.conv_block(base_nf, base_nf*2, kernel_size=3, stride=1, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        conv3 = B.conv_block(base_nf*2, base_nf*2, kernel_size=4, stride=2, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        # 16, 128
        conv4 = B.conv_block(base_nf*2, base_nf*4, kernel_size=3, stride=1, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        conv5 = B.conv_block(base_nf*4, base_nf*4, kernel_size=4, stride=2, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        # 8, 256
        conv6 = B.conv_block(base_nf*4, base_nf*8, kernel_size=3, stride=1, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        conv7 = B.conv_block(base_nf*8, base_nf*8, kernel_size=4, stride=2, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        # 4, 512
        self.features = B.sequential(conv0, conv1, conv2, conv3, conv4, conv5, conv6, conv7)

        # classifier
        self.classifier = nn.Sequential(
            nn.Linear(512 * 4 * 4, 100), nn.LeakyReLU(0.2, True), nn.Linear(100, 1))

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x


# VGG style Discriminator with input size 32*32
class Discriminator_VGG_32(nn.Module):
    def __init__(self, in_nc, base_nf, norm_type='batch', act_type='leakyrelu', mode='CNA'):
        super(Discriminator_VGG_32, self).__init__()
        # features
        # hxw, c
        # 32, 3
        conv0 = B.conv_block(in_nc, base_nf, kernel_size=3, norm_type=None, act_type=act_type, \
            mode=mode)
        conv1 = B.conv_block(base_nf, base_nf, kernel_size=4, stride=2, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        # 16, 64
        conv2 = B.conv_block(base_nf, base_nf*2, kernel_size=3, stride=1, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        conv3 = B.conv_block(base_nf*2, base_nf*2, kernel_size=4, stride=2, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        # 8, 128
        conv4 = B.conv_block(base_nf*2, base_nf*4, kernel_size=3, stride=1, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        conv5 = B.conv_block(base_nf*4, base_nf*4, kernel_size=4, stride=2, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        # 4, 256
        self.features = B.sequential(conv0, conv1, conv2, conv3, conv4, conv5)

        # classifier
        self.classifier = nn.Sequential(
            nn.Linear(256 * 4 * 4, 100), nn.LeakyReLU(0.2, True), nn.Linear(100, 1))

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x


# VGG style Discriminator with input size 16*16
class Discriminator_VGG_16(nn.Module):
    def __init__(self, in_nc, base_nf, norm_type='batch', act_type='leakyrelu', mode='CNA'):
        super(Discriminator_VGG_16, self).__init__()
        # features
        # hxw, c
        # 16, 3
        conv0 = B.conv_block(in_nc, base_nf, kernel_size=3, norm_type=None, act_type=act_type, \
            mode=mode)
        conv1 = B.conv_block(base_nf, base_nf, kernel_size=4, stride=2, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        # 8, 64
        conv2 = B.conv_block(base_nf, base_nf*2, kernel_size=3, stride=1, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        conv3 = B.conv_block(base_nf*2, base_nf*2, kernel_size=4, stride=2, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        # 4, 128
        self.features = B.sequential(conv0, conv1, conv2, conv3)

        # classifier
        self.classifier = nn.Sequential(
            nn.Linear(128 * 4 * 4, 100), nn.LeakyReLU(0.2, True), nn.Linear(100, 1))

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x


# VGG style Discriminator with input size 128*128, Spectral Normalization
class Discriminator_VGG_128_SN(nn.Module):
    def __init__(self):
        super(Discriminator_VGG_128_SN, self).__init__()
        # features
        # hxw, c
        # 128, 64
        self.lrelu = nn.LeakyReLU(0.2, True)

        self.conv0 = SN.spectral_norm(nn.Conv2d(3, 64, 3, 1, 1))
        self.conv1 = SN.spectral_norm(nn.Conv2d(64, 64, 4, 2, 1))
        # 64, 64
        self.conv2 = SN.spectral_norm(nn.Conv2d(64, 128, 3, 1, 1))
        self.conv3 = SN.spectral_norm(nn.Conv2d(128, 128, 4, 2, 1))
        # 32, 128
        self.conv4 = SN.spectral_norm(nn.Conv2d(128, 256, 3, 1, 1))
        self.conv5 = SN.spectral_norm(nn.Conv2d(256, 256, 4, 2, 1))
        # 16, 256
        self.conv6 = SN.spectral_norm(nn.Conv2d(256, 512, 3, 1, 1))
        self.conv7 = SN.spectral_norm(nn.Conv2d(512, 512, 4, 2, 1))
        # 8, 512
        self.conv8 = SN.spectral_norm(nn.Conv2d(512, 512, 3, 1, 1))
        self.conv9 = SN.spectral_norm(nn.Conv2d(512, 512, 4, 2, 1))
        # 4, 512

        # classifier
        self.linear0 = SN.spectral_norm(nn.Linear(512 * 4 * 4, 100))
        self.linear1 = SN.spectral_norm(nn.Linear(100, 1))

    def forward(self, x):
        x = self.lrelu(self.conv0(x))
        x = self.lrelu(self.conv1(x))
        x = self.lrelu(self.conv2(x))
        x = self.lrelu(self.conv3(x))
        x = self.lrelu(self.conv4(x))
        x = self.lrelu(self.conv5(x))
        x = self.lrelu(self.conv6(x))
        x = self.lrelu(self.conv7(x))
        x = self.lrelu(self.conv8(x))
        x = self.lrelu(self.conv9(x))
        x = x.view(x.size(0), -1)
        x = self.lrelu(self.linear0(x))
        x = self.linear1(x)
        return x


class Discriminator_VGG_96(nn.Module):
    def __init__(self, in_nc, base_nf, norm_type='batch', act_type='leakyrelu', mode='CNA'):
        super(Discriminator_VGG_96, self).__init__()
        # features
        # hxw, c
        # 96, 64
        conv0 = B.conv_block(in_nc, base_nf, kernel_size=3, norm_type=None, act_type=act_type, \
            mode=mode)
        conv1 = B.conv_block(base_nf, base_nf, kernel_size=4, stride=2, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        # 48, 64
        conv2 = B.conv_block(base_nf, base_nf*2, kernel_size=3, stride=1, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        conv3 = B.conv_block(base_nf*2, base_nf*2, kernel_size=4, stride=2, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        # 24, 128
        conv4 = B.conv_block(base_nf*2, base_nf*4, kernel_size=3, stride=1, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        conv5 = B.conv_block(base_nf*4, base_nf*4, kernel_size=4, stride=2, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        # 12, 256
        conv6 = B.conv_block(base_nf*4, base_nf*8, kernel_size=3, stride=1, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        conv7 = B.conv_block(base_nf*8, base_nf*8, kernel_size=4, stride=2, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        # 6, 512
        conv8 = B.conv_block(base_nf*8, base_nf*8, kernel_size=3, stride=1, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        conv9 = B.conv_block(base_nf*8, base_nf*8, kernel_size=4, stride=2, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        # 3, 512
        self.features = B.sequential(conv0, conv1, conv2, conv3, conv4, conv5, conv6, conv7, conv8,\
            conv9)

        # classifier
        self.classifier = nn.Sequential(
            nn.Linear(512 * 3 * 3, 100), nn.LeakyReLU(0.2, True), nn.Linear(100, 1))

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x


class Discriminator_VGG_192(nn.Module):
    def __init__(self, in_nc, base_nf, norm_type='batch', act_type='leakyrelu', mode='CNA'):
        super(Discriminator_VGG_192, self).__init__()
        # features
        # hxw, c
        # 192, 64
        conv0 = B.conv_block(in_nc, base_nf, kernel_size=3, norm_type=None, act_type=act_type, \
            mode=mode)
        conv1 = B.conv_block(base_nf, base_nf, kernel_size=4, stride=2, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        # 96, 64
        conv2 = B.conv_block(base_nf, base_nf*2, kernel_size=3, stride=1, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        conv3 = B.conv_block(base_nf*2, base_nf*2, kernel_size=4, stride=2, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        # 48, 128
        conv4 = B.conv_block(base_nf*2, base_nf*4, kernel_size=3, stride=1, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        conv5 = B.conv_block(base_nf*4, base_nf*4, kernel_size=4, stride=2, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        # 24, 256
        conv6 = B.conv_block(base_nf*4, base_nf*8, kernel_size=3, stride=1, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        conv7 = B.conv_block(base_nf*8, base_nf*8, kernel_size=4, stride=2, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        # 12, 512
        conv8 = B.conv_block(base_nf*8, base_nf*8, kernel_size=3, stride=1, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        conv9 = B.conv_block(base_nf*8, base_nf*8, kernel_size=4, stride=2, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        # 6, 512
        conv10 = B.conv_block(base_nf*8, base_nf*8, kernel_size=3, stride=1, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        conv11 = B.conv_block(base_nf*8, base_nf*8, kernel_size=4, stride=2, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        # 3, 512
        self.features = B.sequential(conv0, conv1, conv2, conv3, conv4, conv5, conv6, conv7, conv8,\
            conv9, conv10, conv11)

        # classifier
        self.classifier = nn.Sequential(
            nn.Linear(512 * 3 * 3, 100), nn.LeakyReLU(0.2, True), nn.Linear(100, 1))

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x


####################
# Perceptual Network
####################

class VGGFeatureExtractor(nn.Module):
    def __init__(self,
                 feature_layer=34,
                 use_bn=False,
                 use_input_norm=True,
                 device=torch.device('cpu')):
        super(VGGFeatureExtractor, self).__init__()
        if use_bn:
            model = torchvision.models.vgg19_bn(pretrained=True)
        else:
            model = torchvision.models.vgg19(pretrained=True)
        self.use_input_norm = use_input_norm
        if self.use_input_norm:
            mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device)
            # [0.485-1, 0.456-1, 0.406-1] if input in range [-1,1]
            std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device)
            # [0.229*2, 0.224*2, 0.225*2] if input in range [-1,1]
            self.register_buffer('mean', mean)
            self.register_buffer('std', std)
        self.features = nn.Sequential(*list(model.features.children())[:(feature_layer + 1)])
        # No need to BP to variable
        for k, v in self.features.named_parameters():
            v.requires_grad = False

    def forward(self, x):
        if self.use_input_norm:
            x = (x - self.mean) / self.std
        output = self.features(x)
        return output


class ResNet101FeatureExtractor(nn.Module):
    def __init__(self, use_input_norm=True, device=torch.device('cpu')):
        super(ResNet101FeatureExtractor, self).__init__()
        model = torchvision.models.resnet101(pretrained=True)
        self.use_input_norm = use_input_norm
        if self.use_input_norm:
            mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device)
            # [0.485-1, 0.456-1, 0.406-1] if input in range [-1,1]
            std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device)
            # [0.229*2, 0.224*2, 0.225*2] if input in range [-1,1]
            self.register_buffer('mean', mean)
            self.register_buffer('std', std)
        self.features = nn.Sequential(*list(model.children())[:8])
        # No need to BP to variable
        for k, v in self.features.named_parameters():
            v.requires_grad = False

    def forward(self, x):
        if self.use_input_norm:
            x = (x - self.mean) / self.std
        output = self.features(x)
        return output


class MINCNet(nn.Module):
    def __init__(self):
        super(MINCNet, self).__init__()
        self.ReLU = nn.ReLU(True)
        self.conv11 = nn.Conv2d(3, 64, 3, 1, 1)
        self.conv12 = nn.Conv2d(64, 64, 3, 1, 1)
        self.maxpool1 = nn.MaxPool2d(2, stride=2, padding=0, ceil_mode=True)
        self.conv21 = nn.Conv2d(64, 128, 3, 1, 1)
        self.conv22 = nn.Conv2d(128, 128, 3, 1, 1)
        self.maxpool2 = nn.MaxPool2d(2, stride=2, padding=0, ceil_mode=True)
        self.conv31 = nn.Conv2d(128, 256, 3, 1, 1)
        self.conv32 = nn.Conv2d(256, 256, 3, 1, 1)
        self.conv33 = nn.Conv2d(256, 256, 3, 1, 1)
        self.maxpool3 = nn.MaxPool2d(2, stride=2, padding=0, ceil_mode=True)
        self.conv41 = nn.Conv2d(256, 512, 3, 1, 1)
        self.conv42 = nn.Conv2d(512, 512, 3, 1, 1)
        self.conv43 = nn.Conv2d(512, 512, 3, 1, 1)
        self.maxpool4 = nn.MaxPool2d(2, stride=2, padding=0, ceil_mode=True)
        self.conv51 = nn.Conv2d(512, 512, 3, 1, 1)
        self.conv52 = nn.Conv2d(512, 512, 3, 1, 1)
        self.conv53 = nn.Conv2d(512, 512, 3, 1, 1)

    def forward(self, x):
        out = self.ReLU(self.conv11(x))
        out = self.ReLU(self.conv12(out))
        out = self.maxpool1(out)
        out = self.ReLU(self.conv21(out))
        out = self.ReLU(self.conv22(out))
        out = self.maxpool2(out)
        out = self.ReLU(self.conv31(out))
        out = self.ReLU(self.conv32(out))
        out = self.ReLU(self.conv33(out))
        out = self.maxpool3(out)
        out = self.ReLU(self.conv41(out))
        out = self.ReLU(self.conv42(out))
        out = self.ReLU(self.conv43(out))
        out = self.maxpool4(out)
        out = self.ReLU(self.conv51(out))
        out = self.ReLU(self.conv52(out))
        out = self.conv53(out)
        return out


class MINCFeatureExtractor(nn.Module):
    def __init__(self, feature_layer=34, use_bn=False, use_input_norm=True, \
            device=torch.device('cpu')):
        super(MINCFeatureExtractor, self).__init__()

        self.features = MINCNet()
        self.features.load_state_dict(
            torch.load('../experiments/pretrained_models/VGG16minc_53.pth'), strict=True)
        self.features.eval()
        # No need to BP to variable
        for k, v in self.features.named_parameters():
            v.requires_grad = False

    def forward(self, x):
        output = self.features(x)
        return output
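Reviewer note: both SPSR_model.py and this file define the same fixed-kernel gradient extractor, and SPSRNet.forward returns (x_out_branch, x_out, x_grad): the gradient-branch image, the fused SR image, and the input's gradient map. A minimal standalone sketch of that operator follows; it assumes only PyTorch, copies the kernel values from Get_gradient_nopadding above, and the helper name gradient_magnitude is hypothetical.

import torch
import torch.nn.functional as F

def gradient_magnitude(x, eps=1e-6):
    # x: (N, C, H, W). Non-learnable vertical/horizontal difference kernels, applied per channel.
    kv = torch.tensor([[0., -1., 0.], [0., 0., 0.], [0., 1., 0.]], device=x.device).view(1, 1, 3, 3)
    kh = torch.tensor([[0., 0., 0.], [-1., 0., 1.], [0., 0., 0.]], device=x.device).view(1, 1, 3, 3)
    n, c, h, w = x.shape
    xc = x.reshape(n * c, 1, h, w)            # treat each channel as its own sample
    gv = F.conv2d(xc, kv, padding=1)
    gh = F.conv2d(xc, kh, padding=1)
    g = torch.sqrt(gv * gv + gh * gh + eps)   # per-pixel gradient magnitude
    return g.reshape(n, c, h, w)

if __name__ == '__main__':
    img = torch.rand(1, 3, 8, 8)
    print(gradient_magnitude(img).shape)      # torch.Size([1, 3, 8, 8])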
258
codes/models/SPSR_modules/block.py
Normal file
|
@ -0,0 +1,258 @@
|
||||||
|
from collections import OrderedDict
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
####################
|
||||||
|
# Basic blocks
|
||||||
|
####################
|
||||||
|
|
||||||
|
|
||||||
|
def act(act_type, inplace=True, neg_slope=0.2, n_prelu=1):
|
||||||
|
# helper selecting activation
|
||||||
|
# neg_slope: for leakyrelu and init of prelu
|
||||||
|
# n_prelu: for p_relu num_parameters
|
||||||
|
act_type = act_type.lower()
|
||||||
|
if act_type == 'relu':
|
||||||
|
layer = nn.ReLU(inplace)
|
||||||
|
elif act_type == 'leakyrelu':
|
||||||
|
layer = nn.LeakyReLU(neg_slope, inplace)
|
||||||
|
elif act_type == 'prelu':
|
||||||
|
layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError('activation layer [{:s}] is not found'.format(act_type))
|
||||||
|
return layer
|
||||||
|
|
||||||
|
def norm(norm_type, nc):
|
||||||
|
# helper selecting normalization layer
|
||||||
|
norm_type = norm_type.lower()
|
||||||
|
if norm_type == 'batch':
|
||||||
|
layer = nn.BatchNorm2d(nc, affine=True)
|
||||||
|
elif norm_type == 'instance':
|
||||||
|
layer = nn.InstanceNorm2d(nc, affine=False)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError('normalization layer [{:s}] is not found'.format(norm_type))
|
||||||
|
return layer
|
||||||
|
|
||||||
|
def pad(pad_type, padding):
|
||||||
|
# helper selecting padding layer
|
||||||
|
# if padding is 'zero', do by conv layers
|
||||||
|
pad_type = pad_type.lower()
|
||||||
|
if padding == 0:
|
||||||
|
return None
|
||||||
|
if pad_type == 'reflect':
|
||||||
|
layer = nn.ReflectionPad2d(padding)
|
||||||
|
elif pad_type == 'replicate':
|
||||||
|
layer = nn.ReplicationPad2d(padding)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError('padding layer [{:s}] is not implemented'.format(pad_type))
|
||||||
|
return layer
|
||||||
|
|
||||||
|
|
||||||
|
def get_valid_padding(kernel_size, dilation):
|
||||||
|
kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1)
|
||||||
|
padding = (kernel_size - 1) // 2
|
||||||
|
return padding
|
||||||
|
|
||||||
|
|
||||||
|
class ConcatBlock(nn.Module):
|
||||||
|
# Concat the output of a submodule to its input
|
||||||
|
def __init__(self, submodule):
|
||||||
|
super(ConcatBlock, self).__init__()
|
||||||
|
self.sub = submodule
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
output = torch.cat((x, self.sub(x)), dim=1)
|
||||||
|
return output
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
tmpstr = 'Identity .. \n|'
|
||||||
|
modstr = self.sub.__repr__().replace('\n', '\n|')
|
||||||
|
tmpstr = tmpstr + modstr
|
||||||
|
return tmpstr
|
||||||
|
|
||||||
|
|
||||||
|
class ShortcutBlock(nn.Module):
|
||||||
|
#Elementwise sum the output of a submodule to its input
|
||||||
|
def __init__(self, submodule):
|
||||||
|
super(ShortcutBlock, self).__init__()
|
||||||
|
self.sub = submodule
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return x, self.sub
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
tmpstr = 'Identity + \n|'
|
||||||
|
modstr = self.sub.__repr__().replace('\n', '\n|')
|
||||||
|
tmpstr = tmpstr + modstr
|
||||||
|
return tmpstr
|
||||||
|
|
||||||
|
|
||||||
|
def sequential(*args):
|
||||||
|
# Flatten Sequential. It unwraps nn.Sequential.
|
||||||
|
if len(args) == 1:
|
||||||
|
if isinstance(args[0], OrderedDict):
|
||||||
|
raise NotImplementedError('sequential does not support OrderedDict input.')
|
||||||
|
return args[0] # No sequential is needed.
|
||||||
|
modules = []
|
||||||
|
for module in args:
|
||||||
|
if isinstance(module, nn.Sequential):
|
||||||
|
for submodule in module.children():
|
||||||
|
modules.append(submodule)
|
||||||
|
elif isinstance(module, nn.Module):
|
||||||
|
modules.append(module)
|
||||||
|
return nn.Sequential(*modules)
|
||||||
|
|
||||||
|
|
||||||
|
def conv_block(in_nc, out_nc, kernel_size, stride=1, dilation=1, groups=1, bias=True, \
|
||||||
|
pad_type='zero', norm_type=None, act_type='relu', mode='CNA'):
|
||||||
|
'''
|
||||||
|
Conv layer with padding, normalization, activation
|
||||||
|
mode: CNA --> Conv -> Norm -> Act
|
||||||
|
NAC --> Norm -> Act --> Conv (Identity Mappings in Deep Residual Networks, ECCV16)
|
||||||
|
'''
|
||||||
|
assert mode in ['CNA', 'NAC', 'CNAC'], 'Wrong conv mode [{:s}]'.format(mode)
|
||||||
|
padding = get_valid_padding(kernel_size, dilation)
|
||||||
|
p = pad(pad_type, padding) if pad_type and pad_type != 'zero' else None
|
||||||
|
padding = padding if pad_type == 'zero' else 0
|
||||||
|
|
||||||
|
c = nn.Conv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding, \
|
||||||
|
dilation=dilation, bias=bias, groups=groups)
|
||||||
|
a = act(act_type) if act_type else None
|
||||||
|
if 'CNA' in mode:
|
||||||
|
n = norm(norm_type, out_nc) if norm_type else None
|
||||||
|
return sequential(p, c, n, a)
|
||||||
|
elif mode == 'NAC':
|
||||||
|
if norm_type is None and act_type is not None:
|
||||||
|
a = act(act_type, inplace=False)
|
||||||
|
# Important!
|
||||||
|
# input----ReLU(inplace)----Conv--+----output
|
||||||
|
# |________________________|
|
||||||
|
# inplace ReLU will modify the input, therefore wrong output
|
||||||
|
n = norm(norm_type, in_nc) if norm_type else None
|
||||||
|
return sequential(n, a, p, c)
|
||||||
|
|
||||||
|
|
||||||
|
####################
# Useful blocks
####################


class ResNetBlock(nn.Module):
    '''
    ResNet Block, 3-3 style
    with extra residual scaling used in EDSR
    (Enhanced Deep Residual Networks for Single Image Super-Resolution, CVPRW 17)
    '''

    def __init__(self, in_nc, mid_nc, out_nc, kernel_size=3, stride=1, dilation=1, groups=1,
                 bias=True, pad_type='zero', norm_type=None, act_type='relu', mode='CNA', res_scale=1):
        super(ResNetBlock, self).__init__()
        conv0 = conv_block(in_nc, mid_nc, kernel_size, stride, dilation, groups, bias, pad_type,
                           norm_type, act_type, mode)
        if mode == 'CNA':
            act_type = None
        if mode == 'CNAC':  # Residual path: |-CNAC-|
            act_type = None
            norm_type = None
        conv1 = conv_block(mid_nc, out_nc, kernel_size, stride, dilation, groups, bias, pad_type,
                           norm_type, act_type, mode)
        # if in_nc != out_nc:
        #     self.project = conv_block(in_nc, out_nc, 1, stride, dilation, 1, bias, pad_type,
        #                               None, None)
        #     print('Need a projection layer in ResNetBlock.')
        # else:
        #     self.project = lambda x: x

        self.res = sequential(conv0, conv1)
        self.res_scale = res_scale

    def forward(self, x):
        res = self.res(x).mul(self.res_scale)
        return x + res

class ResidualDenseBlock_5C(nn.Module):
    '''
    Residual Dense Block
    style: 5 convs
    The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)
    '''

    def __init__(self, nc, kernel_size=3, gc=32, stride=1, bias=True, pad_type='zero',
                 norm_type=None, act_type='leakyrelu', mode='CNA'):
        super(ResidualDenseBlock_5C, self).__init__()
        # gc: growth channel, i.e. intermediate channels
        self.conv1 = conv_block(nc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
                                norm_type=norm_type, act_type=act_type, mode=mode)
        self.conv2 = conv_block(nc+gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
                                norm_type=norm_type, act_type=act_type, mode=mode)
        self.conv3 = conv_block(nc+2*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
                                norm_type=norm_type, act_type=act_type, mode=mode)
        self.conv4 = conv_block(nc+3*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
                                norm_type=norm_type, act_type=act_type, mode=mode)
        if mode == 'CNA':
            last_act = None
        else:
            last_act = act_type
        self.conv5 = conv_block(nc+4*gc, nc, 3, stride, bias=bias, pad_type=pad_type,
                                norm_type=norm_type, act_type=last_act, mode=mode)

    def forward(self, x):
        x1 = self.conv1(x)
        x2 = self.conv2(torch.cat((x, x1), 1))
        x3 = self.conv3(torch.cat((x, x1, x2), 1))
        x4 = self.conv4(torch.cat((x, x1, x2, x3), 1))
        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
        return x5.mul(0.2) + x

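Each of the first four convolutions emits gc new channels that are concatenated onto the running input, and conv5 projects the nc+4*gc stack back down to nc, so the block can sit on a residual connection. A quick shape check with illustrative sizes:

    rdb = ResidualDenseBlock_5C(nc=64, gc=32)
    x = torch.randn(1, 64, 24, 24)
    assert rdb(x).shape == x.shape   # channel count and spatial size are preserved
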
class RRDB(nn.Module):
    '''
    Residual in Residual Dense Block
    (ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks)
    '''

    def __init__(self, nc, kernel_size=3, gc=32, stride=1, bias=True, pad_type='zero',
                 norm_type=None, act_type='leakyrelu', mode='CNA'):
        super(RRDB, self).__init__()
        self.RDB1 = ResidualDenseBlock_5C(nc, kernel_size, gc, stride, bias, pad_type,
                                          norm_type, act_type, mode)
        self.RDB2 = ResidualDenseBlock_5C(nc, kernel_size, gc, stride, bias, pad_type,
                                          norm_type, act_type, mode)
        self.RDB3 = ResidualDenseBlock_5C(nc, kernel_size, gc, stride, bias, pad_type,
                                          norm_type, act_type, mode)

    def forward(self, x):
        out = self.RDB1(x)
        out = self.RDB2(out)
        out = self.RDB3(out)
        return out.mul(0.2) + x

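Because RRDB keeps the channel count fixed, the generator trunk can simply stack these blocks; a hedged sketch (nb=23 is the usual ESRGAN depth, not a value read from this diff):

    trunk = sequential(*[RRDB(nc=64, gc=32) for _ in range(23)])
    feat = trunk(torch.randn(1, 64, 24, 24))   # same shape in and out
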
####################
# Upsampler
####################


def pixelshuffle_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True,
                       pad_type='zero', norm_type=None, act_type='relu'):
    '''
    Pixel shuffle layer
    (Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional
    Neural Network, CVPR17)
    '''
    conv = conv_block(in_nc, out_nc * (upscale_factor ** 2), kernel_size, stride, bias=bias,
                      pad_type=pad_type, norm_type=None, act_type=None)
    pixel_shuffle = nn.PixelShuffle(upscale_factor)

    n = norm(norm_type, out_nc) if norm_type else None
    a = act(act_type) if act_type else None
    return sequential(conv, pixel_shuffle, n, a)


# Note: the misspelled name 'upconv_blcok' is kept deliberately; renaming it would
# break any caller that imports it by this spelling.
def upconv_blcok(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True,
                 pad_type='zero', norm_type=None, act_type='relu', mode='nearest'):
    # Upsample-then-convolve block,
    # described in https://distill.pub/2016/deconv-checkerboard/
    upsample = nn.Upsample(scale_factor=upscale_factor, mode=mode)
    conv = conv_block(in_nc, out_nc, kernel_size, stride, bias=bias,
                      pad_type=pad_type, norm_type=norm_type, act_type=act_type)
    return sequential(upsample, conv)

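A x4 network typically chains two x2 upsampling stages; a small sketch with the two helpers above:

    up1 = upconv_blcok(64, 64)                          # nearest-neighbour upsample + conv (avoids checkerboard artifacts)
    up2 = pixelshuffle_block(64, 64, upscale_factor=2)  # sub-pixel convolution alternative
    x = torch.randn(1, 64, 16, 16)
    print(up2(up1(x)).shape)                            # torch.Size([1, 64, 64, 64])
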
60
codes/models/SPSR_modules/loss.py
Normal file
@@ -0,0 +1,60 @@
import torch
import torch.nn as nn


# Define GAN loss: [vanilla | lsgan | wgan-gp]
class GANLoss(nn.Module):
    def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0):
        super(GANLoss, self).__init__()
        self.gan_type = gan_type.lower()
        self.real_label_val = real_label_val
        self.fake_label_val = fake_label_val

        if self.gan_type == 'vanilla':
            self.loss = nn.BCEWithLogitsLoss()
        elif self.gan_type == 'lsgan':
            self.loss = nn.MSELoss()
        elif self.gan_type == 'wgan-gp':

            def wgan_loss(input, target):
                # target is boolean
                return -1 * input.mean() if target else input.mean()

            self.loss = wgan_loss
        else:
            raise NotImplementedError('GAN type [{:s}] is not found'.format(self.gan_type))

    def get_target_label(self, input, target_is_real):
        if self.gan_type == 'wgan-gp':
            return target_is_real
        if target_is_real:
            return torch.empty_like(input).fill_(self.real_label_val)
        else:
            return torch.empty_like(input).fill_(self.fake_label_val)

    def forward(self, input, target_is_real):
        target_label = self.get_target_label(input, target_is_real)
        loss = self.loss(input, target_label)
        return loss


class GradientPenaltyLoss(nn.Module):
    def __init__(self, device=torch.device('cpu')):
        super(GradientPenaltyLoss, self).__init__()
        self.register_buffer('grad_outputs', torch.Tensor())
        self.grad_outputs = self.grad_outputs.to(device)

    def get_grad_outputs(self, input):
        if self.grad_outputs.size() != input.size():
            self.grad_outputs.resize_(input.size()).fill_(1.0)
        return self.grad_outputs

    def forward(self, interp, interp_crit):
        grad_outputs = self.get_grad_outputs(interp_crit)
        grad_interp = torch.autograd.grad(outputs=interp_crit, inputs=interp,
                                          grad_outputs=grad_outputs, create_graph=True,
                                          retain_graph=True, only_inputs=True)[0]
        grad_interp = grad_interp.view(grad_interp.size(0), -1)
        grad_interp_norm = grad_interp.norm(2, dim=1)

        loss = ((grad_interp_norm - 1)**2).mean()
        return loss
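A minimal sketch of how this criterion is driven from a GAN training loop (the tensors below are stand-ins for discriminator outputs, not code from SPSR_model.py):

    cri_gan = GANLoss('vanilla')                  # BCE-with-logits on raw discriminator scores
    pred_real = torch.randn(4, 1)
    pred_fake = torch.randn(4, 1)
    l_d = cri_gan(pred_real, True) + cri_gan(pred_fake, False)   # discriminator objective
    l_g = cri_gan(pred_fake, True)                               # generator wants fakes scored as real
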
94
codes/models/SPSR_modules/sampler.py
Normal file
@@ -0,0 +1,94 @@
import random
import torch
import numpy as np


def _get_random_crop_indices(crop_region, crop_size):
    '''
    crop_region: (start_y, end_y, start_x, end_x)
    crop_size: (y, x)
    '''
    region_size = (crop_region[1] - crop_region[0], crop_region[3] - crop_region[2])
    if region_size[0] < crop_size[0] or region_size[1] < crop_size[1]:
        print(region_size, crop_size)
    assert region_size[0] >= crop_size[0] and region_size[1] >= crop_size[1]
    if region_size[0] == crop_size[0]:
        start_y = crop_region[0]
    else:
        start_y = random.choice(range(crop_region[0], crop_region[1] - crop_size[0]))
    if region_size[1] == crop_size[1]:
        start_x = crop_region[2]
    else:
        start_x = random.choice(range(crop_region[2], crop_region[3] - crop_size[1]))
    return start_y, start_y + crop_size[0], start_x, start_x + crop_size[1]


def _get_adaptive_crop_indices(crop_region, crop_size, num_candidate, dist_map, min_diff=False):
    candidates = [_get_random_crop_indices(crop_region, crop_size) for _ in range(num_candidate)]
    max_choice = candidates[0]
    min_choice = candidates[0]
    max_dist = 0
    min_dist = np.infty
    with torch.no_grad():
        for c in candidates:
            start_y, end_y, start_x, end_x = c
            dist = torch.sum(dist_map[start_y: end_y, start_x: end_x])
            if dist > max_dist:
                max_dist = dist
                max_choice = c
            if dist < min_dist:
                min_dist = dist
                min_choice = c
    if min_diff:
        return min_choice
    else:
        return max_choice


def get_split_list(divisor, dividend):
    split_list = [dividend // divisor for _ in range(divisor - 1)]
    split_list.append(dividend - (dividend // divisor) * (divisor - 1))
    return split_list


def random_sampler(pic_size, crop_dict):
    crop_region = (0, pic_size[0], 0, pic_size[1])
    crop_res_dict = {}
    for k, v in crop_dict.items():
        crop_size = (int(k), int(k))
        crop_res_dict[k] = [_get_random_crop_indices(crop_region, crop_size) for _ in range(v)]
    return crop_res_dict


def region_sampler(crop_region, crop_dict):
    crop_res_dict = {}
    for k, v in crop_dict.items():
        crop_size = (int(k), int(k))
        crop_res_dict[k] = [_get_random_crop_indices(crop_region, crop_size) for _ in range(v)]
    return crop_res_dict


def adaptive_sampler(pic_size, crop_dict, num_candidate_dict, dist_map, min_diff=False):
    crop_region = (0, pic_size[0], 0, pic_size[1])
    crop_res_dict = {}
    for k, v in crop_dict.items():
        crop_size = (int(k), int(k))
        crop_res_dict[k] = [_get_adaptive_crop_indices(crop_region, crop_size, num_candidate_dict[k], dist_map, min_diff)
                            for _ in range(v)]
    return crop_res_dict


# TODO: make this more flexible
def pyramid_sampler(pic_size, crop_dict):
    crop_res_dict = {}
    sorted_key = list(crop_dict.keys())
    sorted_key.sort(key=lambda x: int(x), reverse=True)
    k = sorted_key[0]
    crop_size = (int(k), int(k))
    crop_region = (0, pic_size[0], 0, pic_size[1])
    crop_res_dict[k] = [_get_random_crop_indices(crop_region, crop_size) for _ in range(crop_dict[k])]

    for i in range(1, len(sorted_key)):
        crop_res_dict[sorted_key[i]] = []
        afore_num = crop_dict[sorted_key[i-1]]
        new_num = crop_dict[sorted_key[i]]
        split_list = get_split_list(afore_num, new_num)
        crop_size = (int(sorted_key[i]), int(sorted_key[i]))
        for j in range(len(split_list)):
            crop_region = crop_res_dict[sorted_key[i-1]][j]
            crop_res_dict[sorted_key[i]].extend(
                [_get_random_crop_indices(crop_region, crop_size) for _ in range(split_list[j])])

    return crop_res_dict
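The samplers return crop coordinates keyed by the crop size given as a string; an illustrative call:

    crops = random_sampler(pic_size=(128, 128), crop_dict={'64': 2, '32': 4})
    # crops['64'] -> two (start_y, end_y, start_x, end_x) tuples for 64x64 patches,
    # crops['32'] -> four tuples for 32x32 patches, all inside the 128x128 image.
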
149
codes/models/SPSR_modules/spectral_norm.py
Normal file
@@ -0,0 +1,149 @@
'''
Copied from the PyTorch GitHub repo.
Spectral Normalization from https://arxiv.org/abs/1802.05957
'''
import torch
from torch.nn.functional import normalize
from torch.nn.parameter import Parameter


class SpectralNorm(object):
    def __init__(self, name='weight', n_power_iterations=1, dim=0, eps=1e-12):
        self.name = name
        self.dim = dim
        if n_power_iterations <= 0:
            raise ValueError('Expected n_power_iterations to be positive, but '
                             'got n_power_iterations={}'.format(n_power_iterations))
        self.n_power_iterations = n_power_iterations
        self.eps = eps

    def compute_weight(self, module):
        weight = getattr(module, self.name + '_orig')
        u = getattr(module, self.name + '_u')
        weight_mat = weight
        if self.dim != 0:
            # permute dim to front
            weight_mat = weight_mat.permute(self.dim,
                                            *[d for d in range(weight_mat.dim()) if d != self.dim])
        height = weight_mat.size(0)
        weight_mat = weight_mat.reshape(height, -1)
        with torch.no_grad():
            for _ in range(self.n_power_iterations):
                # Spectral norm of weight equals to `u^T W v`, where `u` and `v`
                # are the first left and right singular vectors.
                # This power iteration produces approximations of `u` and `v`.
                v = normalize(torch.matmul(weight_mat.t(), u), dim=0, eps=self.eps)
                u = normalize(torch.matmul(weight_mat, v), dim=0, eps=self.eps)

        sigma = torch.dot(u, torch.matmul(weight_mat, v))
        weight = weight / sigma
        return weight, u

    def remove(self, module):
        weight = getattr(module, self.name)
        delattr(module, self.name)
        delattr(module, self.name + '_u')
        delattr(module, self.name + '_orig')
        module.register_parameter(self.name, torch.nn.Parameter(weight))

    def __call__(self, module, inputs):
        if module.training:
            weight, u = self.compute_weight(module)
            setattr(module, self.name, weight)
            setattr(module, self.name + '_u', u)
        else:
            r_g = getattr(module, self.name + '_orig').requires_grad
            getattr(module, self.name).detach_().requires_grad_(r_g)

    @staticmethod
    def apply(module, name, n_power_iterations, dim, eps):
        fn = SpectralNorm(name, n_power_iterations, dim, eps)
        weight = module._parameters[name]
        height = weight.size(dim)

        u = normalize(weight.new_empty(height).normal_(0, 1), dim=0, eps=fn.eps)
        delattr(module, fn.name)
        module.register_parameter(fn.name + "_orig", weight)
        # We still need to assign weight back as fn.name because all sorts of
        # things may assume that it exists, e.g., when initializing weights.
        # However, we can't directly assign as it could be an nn.Parameter and
        # gets added as a parameter. Instead, we register weight.data as a
        # buffer, which will cause weight to be included in the state dict
        # and also supports nn.init due to shared storage.
        module.register_buffer(fn.name, weight.data)
        module.register_buffer(fn.name + "_u", u)

        module.register_forward_pre_hook(fn)
        return fn


def spectral_norm(module, name='weight', n_power_iterations=1, eps=1e-12, dim=None):
    r"""Applies spectral normalization to a parameter in the given module.

    .. math::
        \mathbf{W} &= \dfrac{\mathbf{W}}{\sigma(\mathbf{W})} \\
        \sigma(\mathbf{W}) &= \max_{\mathbf{h}: \mathbf{h} \ne 0} \dfrac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2}

    Spectral normalization stabilizes the training of discriminators (critics)
    in Generative Adversarial Networks (GANs) by rescaling the weight tensor
    with spectral norm :math:`\sigma` of the weight matrix calculated using
    power iteration method. If the dimension of the weight tensor is greater
    than 2, it is reshaped to 2D in power iteration method to get spectral
    norm. This is implemented via a hook that calculates spectral norm and
    rescales weight before every :meth:`~Module.forward` call.

    See `Spectral Normalization for Generative Adversarial Networks`_ .

    .. _`Spectral Normalization for Generative Adversarial Networks`: https://arxiv.org/abs/1802.05957

    Args:
        module (nn.Module): containing module
        name (str, optional): name of weight parameter
        n_power_iterations (int, optional): number of power iterations to
            calculate spectral norm
        eps (float, optional): epsilon for numerical stability in
            calculating norms
        dim (int, optional): dimension corresponding to number of outputs,
            the default is 0, except for modules that are instances of
            ConvTranspose1/2/3d, when it is 1

    Returns:
        The original module with the spectral norm hook

    Example::

        >>> m = spectral_norm(nn.Linear(20, 40))
        Linear (20 -> 40)
        >>> m.weight_u.size()
        torch.Size([20])

    """
    if dim is None:
        if isinstance(
                module,
                (torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d)):
            dim = 1
        else:
            dim = 0
    SpectralNorm.apply(module, name, n_power_iterations, dim, eps)
    return module


def remove_spectral_norm(module, name='weight'):
    r"""Removes the spectral normalization reparameterization from a module.

    Args:
        module (nn.Module): containing module
        name (str, optional): name of weight parameter

    Example:
        >>> m = spectral_norm(nn.Linear(40, 10))
        >>> remove_spectral_norm(m)
    """
    for k, hook in module._forward_pre_hooks.items():
        if isinstance(hook, SpectralNorm) and hook.name == name:
            hook.remove(module)
            del module._forward_pre_hooks[k]
            return module

    raise ValueError("spectral_norm of '{}' not found in {}".format(name, module))
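Usage mirrors torch.nn.utils.spectral_norm; a short sketch with this local copy:

    import torch.nn as nn
    conv = spectral_norm(nn.Conv2d(64, 64, 3, padding=1))  # weight is re-normalised by a power-iteration hook before each forward
    print(conv.weight_orig.shape, conv.weight_u.shape)     # torch.Size([64, 64, 3, 3]) torch.Size([64])
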
161
codes/models/SPSR_networks.py
Normal file
@@ -0,0 +1,161 @@
import functools
import logging
import torch
import torch.nn as nn
from torch.nn import init

import models.SPSR_modules.architecture as arch
logger = logging.getLogger('base')

####################
# initialize
####################


def weights_init_normal(m, std=0.02):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        init.normal_(m.weight.data, 0.0, std)
        if m.bias is not None:
            m.bias.data.zero_()
    elif classname.find('Linear') != -1:
        init.normal_(m.weight.data, 0.0, std)
        if m.bias is not None:
            m.bias.data.zero_()
    elif classname.find('BatchNorm2d') != -1:
        init.normal_(m.weight.data, 1.0, std)  # BN weights are also drawn from a normal distribution
        init.constant_(m.bias.data, 0.0)


def weights_init_kaiming(m, scale=1):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
        m.weight.data *= scale
        if m.bias is not None:
            m.bias.data.zero_()
    elif classname.find('Linear') != -1:
        init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
        m.weight.data *= scale
        if m.bias is not None:
            m.bias.data.zero_()
    elif classname.find('BatchNorm2d') != -1:
        if m.affine:
            init.constant_(m.weight.data, 1.0)
            init.constant_(m.bias.data, 0.0)


def weights_init_orthogonal(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        init.orthogonal_(m.weight.data, gain=1)
        if m.bias is not None:
            m.bias.data.zero_()
    elif classname.find('Linear') != -1:
        init.orthogonal_(m.weight.data, gain=1)
        if m.bias is not None:
            m.bias.data.zero_()
    elif classname.find('BatchNorm2d') != -1:
        init.constant_(m.weight.data, 1.0)
        init.constant_(m.bias.data, 0.0)


def init_weights(net, init_type='kaiming', scale=1, std=0.02):
    # scale is used for 'kaiming', std for 'normal'.
    if init_type == 'normal':
        weights_init_normal_ = functools.partial(weights_init_normal, std=std)
        net.apply(weights_init_normal_)
    elif init_type == 'kaiming':
        weights_init_kaiming_ = functools.partial(weights_init_kaiming, scale=scale)
        net.apply(weights_init_kaiming_)
    elif init_type == 'orthogonal':
        net.apply(weights_init_orthogonal)
    else:
        raise NotImplementedError('initialization method [{:s}] not implemented'.format(init_type))

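A sketch of the initializer entry point; scale=0.1 matches what define_G below passes for the generator:

    net = nn.Sequential(nn.Conv2d(3, 64, 3, padding=1), nn.LeakyReLU(0.2))
    init_weights(net, init_type='kaiming', scale=0.1)   # Kaiming fan-in init, weights then damped by 0.1
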
####################
# define network
####################


# Generator
def define_G(opt, device=None):
    opt_net = opt['network_G']
    which_model = opt_net['which_model_G']

    if which_model == 'spsr_net':
        netG = arch.SPSRNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'],
                            nb=opt_net['nb'], gc=opt_net['gc'], upscale=opt_net['scale'],
                            norm_type=opt_net['norm_type'], act_type='leakyrelu',
                            mode=opt_net['mode'], upsample_mode='upconv')
    else:
        raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model))

    if opt['is_train']:
        init_weights(netG, init_type='kaiming', scale=0.1)

    return netG


# Discriminator
def define_D(opt):
    opt_net = opt['network_D']
    which_model = opt_net['which_model_D']

    if which_model == 'discriminator_vgg_128':
        netD = arch.Discriminator_VGG_128(in_nc=opt_net['in_nc'], base_nf=opt_net['nf'],
                                          norm_type=opt_net['norm_type'], mode=opt_net['mode'],
                                          act_type=opt_net['act_type'])
    elif which_model == 'discriminator_vgg_96':
        netD = arch.Discriminator_VGG_96(in_nc=opt_net['in_nc'], base_nf=opt_net['nf'],
                                         norm_type=opt_net['norm_type'], mode=opt_net['mode'],
                                         act_type=opt_net['act_type'])
    elif which_model == 'discriminator_vgg_192':
        netD = arch.Discriminator_VGG_192(in_nc=opt_net['in_nc'], base_nf=opt_net['nf'],
                                          norm_type=opt_net['norm_type'], mode=opt_net['mode'],
                                          act_type=opt_net['act_type'])
    elif which_model == 'discriminator_vgg_128_SN':
        netD = arch.Discriminator_VGG_128_SN()
    else:
        raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model))

    init_weights(netD, init_type='kaiming', scale=1)

    return netD


# Discriminator for the gradient branch (mirrors define_D)
def define_D_grad(opt):
    opt_net = opt['network_D']
    which_model = opt_net['which_model_D']

    if which_model == 'discriminator_vgg_128':
        netD = arch.Discriminator_VGG_128(in_nc=opt_net['in_nc'], base_nf=opt_net['nf'],
                                          norm_type=opt_net['norm_type'], mode=opt_net['mode'],
                                          act_type=opt_net['act_type'])
    elif which_model == 'discriminator_vgg_96':
        netD = arch.Discriminator_VGG_96(in_nc=opt_net['in_nc'], base_nf=opt_net['nf'],
                                         norm_type=opt_net['norm_type'], mode=opt_net['mode'],
                                         act_type=opt_net['act_type'])
    elif which_model == 'discriminator_vgg_192':
        netD = arch.Discriminator_VGG_192(in_nc=opt_net['in_nc'], base_nf=opt_net['nf'],
                                          norm_type=opt_net['norm_type'], mode=opt_net['mode'],
                                          act_type=opt_net['act_type'])
    elif which_model == 'discriminator_vgg_128_SN':
        netD = arch.Discriminator_VGG_128_SN()
    else:
        raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model))

    init_weights(netD, init_type='kaiming', scale=1)

    return netD


# Perceptual (VGG feature) network
def define_F(opt, use_bn=False):
    device = torch.device('cuda')
    # PyTorch pretrained VGG19-54, feature taken before the ReLU.
    if use_bn:
        feature_layer = 49
    else:
        feature_layer = 34
    netF = arch.VGGFeatureExtractor(feature_layer=feature_layer, use_bn=use_bn,
                                    use_input_norm=True, device=device)

    netF.eval()
    return netF
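define_G only reads opt['network_G'] (plus opt['is_train']); an illustrative options fragment with plausible values, since the actual train_imgset_spsr.yml is not part of this excerpt:

    opt = {
        'is_train': True,
        'network_G': {
            'which_model_G': 'spsr_net',
            'in_nc': 3, 'out_nc': 3, 'nf': 64, 'nb': 23, 'gc': 32,
            'scale': 4, 'norm_type': None, 'mode': 'CNA',
        },
    }
    netG = define_G(opt)   # builds arch.SPSRNet and applies the scaled Kaiming init
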
@@ -11,6 +11,8 @@ def create_model(opt):
         from .SRGAN_model import SRGANModel as M
     elif model == 'feat':
         from .feature_model import FeatureModel as M
+    elif model == 'spsr':
+        from .SPSR_model import SPSRModel as M
     else:
         raise NotImplementedError('Model [{:s}] not recognized.'.format(model))
     m = M(opt)

@@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs):
 def main():
     #### options
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_pixgan_srg4_lr_feat.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_spsr.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
                         help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)