Integrate SPSR into SRGAN_model

SPSR_model really isn't that different from SRGAN_model. Rather than continuing to re-implement
everything I've already done in SRGAN_model, port the new SPSR pieces over.

This really demonstrates the need to refactor SRGAN_model a bit to make it cleaner. It has become
quite the beast these days.
James Betker 2020-08-02 12:55:08 -06:00
parent c8da78966b
commit 328afde9c0
14 changed files with 542 additions and 875 deletions
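The integration hinges on one flag in SRGAN_model: anything with 'spsr' in the model name switches on the gradient-branch machinery (a second discriminator, gradient losses, an extra optimizer). A minimal, illustrative sketch of that gating pattern, with made-up names rather than the repo's real classes:

class SketchSRGANModel:
    def __init__(self, opt):
        # A model name containing 'spsr' enables the extra gradient-branch pieces.
        self.spsr_enabled = 'spsr' in opt['model']
        self.netD_grad = None
        if self.spsr_enabled:
            # In the real code this is a second discriminator that only sees image gradients.
            self.netD_grad = "D_grad stand-in"

print(SketchSRGANModel({'model': 'spsrgan'}).spsr_enabled)   # True
print(SketchSRGANModel({'model': 'srgan'}).spsr_enabled)     # False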

View File

@@ -7,84 +7,14 @@ import torch.nn as nn
from torch.optim import lr_scheduler
from apex import amp
-import models.SPSR_networks as networks
+import models.networks as networks
from .base_model import BaseModel
-from models.SPSR_modules.loss import GANLoss
+from models.loss import GANLoss
import torchvision.utils as utils
+from .archs.SPSR_arch import ImageGradient, ImageGradientNoPadding
logger = logging.getLogger('base')
-import torch.nn.functional as F
class Get_gradient(nn.Module):
def __init__(self):
super(Get_gradient, self).__init__()
kernel_v = [[0, -1, 0],
[0, 0, 0],
[0, 1, 0]]
kernel_h = [[0, 0, 0],
[-1, 0, 1],
[0, 0, 0]]
kernel_h = torch.FloatTensor(kernel_h).unsqueeze(0).unsqueeze(0)
kernel_v = torch.FloatTensor(kernel_v).unsqueeze(0).unsqueeze(0)
self.weight_h = nn.Parameter(data = kernel_h, requires_grad = False).cuda()
self.weight_v = nn.Parameter(data = kernel_v, requires_grad = False).cuda()
def forward(self, x):
x0 = x[:, 0]
x1 = x[:, 1]
x2 = x[:, 2]
x0_v = F.conv2d(x0.unsqueeze(1), self.weight_v, padding=2)
x0_h = F.conv2d(x0.unsqueeze(1), self.weight_h, padding=2)
x1_v = F.conv2d(x1.unsqueeze(1), self.weight_v, padding=2)
x1_h = F.conv2d(x1.unsqueeze(1), self.weight_h, padding=2)
x2_v = F.conv2d(x2.unsqueeze(1), self.weight_v, padding=2)
x2_h = F.conv2d(x2.unsqueeze(1), self.weight_h, padding=2)
x0 = torch.sqrt(torch.pow(x0_v, 2) + torch.pow(x0_h, 2) + 1e-6)
x1 = torch.sqrt(torch.pow(x1_v, 2) + torch.pow(x1_h, 2) + 1e-6)
x2 = torch.sqrt(torch.pow(x2_v, 2) + torch.pow(x2_h, 2) + 1e-6)
x = torch.cat([x0, x1, x2], dim=1)
return x
class Get_gradient_nopadding(nn.Module):
def __init__(self):
super(Get_gradient_nopadding, self).__init__()
kernel_v = [[0, -1, 0],
[0, 0, 0],
[0, 1, 0]]
kernel_h = [[0, 0, 0],
[-1, 0, 1],
[0, 0, 0]]
kernel_h = torch.FloatTensor(kernel_h).unsqueeze(0).unsqueeze(0)
kernel_v = torch.FloatTensor(kernel_v).unsqueeze(0).unsqueeze(0)
self.weight_h = nn.Parameter(data = kernel_h, requires_grad = False).cuda()
self.weight_v = nn.Parameter(data = kernel_v, requires_grad = False).cuda()
def forward(self, x):
x0 = x[:, 0]
x1 = x[:, 1]
x2 = x[:, 2]
x0_v = F.conv2d(x0.unsqueeze(1), self.weight_v, padding = 1)
x0_h = F.conv2d(x0.unsqueeze(1), self.weight_h, padding = 1)
x1_v = F.conv2d(x1.unsqueeze(1), self.weight_v, padding = 1)
x1_h = F.conv2d(x1.unsqueeze(1), self.weight_h, padding = 1)
x2_v = F.conv2d(x2.unsqueeze(1), self.weight_v, padding = 1)
x2_h = F.conv2d(x2.unsqueeze(1), self.weight_h, padding = 1)
x0 = torch.sqrt(torch.pow(x0_v, 2) + torch.pow(x0_h, 2) + 1e-6)
x1 = torch.sqrt(torch.pow(x1_v, 2) + torch.pow(x1_h, 2) + 1e-6)
x2 = torch.sqrt(torch.pow(x2_v, 2) + torch.pow(x2_h, 2) + 1e-6)
x = torch.cat([x0, x1, x2], dim=1)
return x
class SPSRModel(BaseModel): class SPSRModel(BaseModel):
def __init__(self, opt): def __init__(self, opt):
super(SPSRModel, self).__init__(opt) super(SPSRModel, self).__init__(opt)
@ -93,8 +23,8 @@ class SPSRModel(BaseModel):
# define networks and load pretrained models # define networks and load pretrained models
self.netG = networks.define_G(opt).to(self.device) # G self.netG = networks.define_G(opt).to(self.device) # G
if self.is_train: if self.is_train:
self.netD = networks.define_D(opt).to(self.device) # D self.netD = networks.define_D(opt).to(self.device) # D
-            self.netD_grad = networks.define_D_grad(opt).to(self.device)  # D_grad
+            self.netD_grad = networks.define_D(opt).to(self.device)  # D_grad
self.netG.train() self.netG.train()
self.netD.train() self.netD.train()
self.netD_grad.train() self.netD_grad.train()
@ -142,8 +72,8 @@ class SPSRModel(BaseModel):
self.D_update_ratio = train_opt['D_update_ratio'] if train_opt['D_update_ratio'] else 1 self.D_update_ratio = train_opt['D_update_ratio'] if train_opt['D_update_ratio'] else 1
self.D_init_iters = train_opt['D_init_iters'] if train_opt['D_init_iters'] else 0 self.D_init_iters = train_opt['D_init_iters'] if train_opt['D_init_iters'] else 0
# Branch_init_iters # Branch_init_iters
-        self.Branch_pretrain = train_opt['Branch_pretrain'] if train_opt['Branch_pretrain'] else 0
-        self.Branch_init_iters = train_opt['Branch_init_iters'] if train_opt['Branch_init_iters'] else 1
+        self.branch_pretrain = train_opt['branch_pretrain'] if train_opt['branch_pretrain'] else 0
+        self.branch_init_iters = train_opt['branch_init_iters'] if train_opt['branch_init_iters'] else 1
# gradient_pixel_loss # gradient_pixel_loss
if train_opt['gradient_pixel_weight'] > 0: if train_opt['gradient_pixel_weight'] > 0:
@ -217,8 +147,8 @@ class SPSRModel(BaseModel):
raise NotImplementedError('MultiStepLR learning rate scheme is enough.') raise NotImplementedError('MultiStepLR learning rate scheme is enough.')
self.log_dict = OrderedDict() self.log_dict = OrderedDict()
-        self.get_grad = Get_gradient()
-        self.get_grad_nopadding = Get_gradient_nopadding()
+        self.get_grad = ImageGradient()
+        self.get_grad_nopadding = ImageGradientNoPadding()
def feed_data(self, data, need_HR=True): def feed_data(self, data, need_HR=True):
# LR # LR
@ -232,6 +162,12 @@ class SPSRModel(BaseModel):
def optimize_parameters(self, step): def optimize_parameters(self, step):
# Some generators have variants depending on the current step.
if hasattr(self.netG.module, "update_for_step"):
self.netG.module.update_for_step(step, os.path.join(self.opt['path']['models'], ".."))
if hasattr(self.netD.module, "update_for_step"):
self.netD.module.update_for_step(step, os.path.join(self.opt['path']['models'], ".."))
# G # G
for p in self.netD.parameters(): for p in self.netD.parameters():
p.requires_grad = False p.requires_grad = False
@ -239,9 +175,8 @@ class SPSRModel(BaseModel):
for p in self.netD_grad.parameters(): for p in self.netD_grad.parameters():
p.requires_grad = False p.requires_grad = False
-        if(self.Branch_pretrain):
-            if(step < self.Branch_init_iters):
+        if(self.branch_pretrain):
+            if(step < self.branch_init_iters):
for k,v in self.netG.named_parameters(): for k,v in self.netG.named_parameters():
if 'f_' not in k : if 'f_' not in k :
v.requires_grad=False v.requires_grad=False
@ -250,7 +185,6 @@ class SPSRModel(BaseModel):
if 'f_' not in k : if 'f_' not in k :
v.requires_grad=True v.requires_grad=True
self.optimizer_G.zero_grad() self.optimizer_G.zero_grad()
self.fake_H_branch = [] self.fake_H_branch = []
@ -361,43 +295,49 @@ class SPSRModel(BaseModel):
os.makedirs(os.path.join(sample_save_path, "hr"), exist_ok=True) os.makedirs(os.path.join(sample_save_path, "hr"), exist_ok=True)
os.makedirs(os.path.join(sample_save_path, "lr"), exist_ok=True) os.makedirs(os.path.join(sample_save_path, "lr"), exist_ok=True)
os.makedirs(os.path.join(sample_save_path, "gen"), exist_ok=True) os.makedirs(os.path.join(sample_save_path, "gen"), exist_ok=True)
os.makedirs(os.path.join(sample_save_path, "gen_grad"), exist_ok=True)
# fed_LQ is not chunked. # fed_LQ is not chunked.
utils.save_image(self.var_H[0].cpu(), os.path.join(sample_save_path, "hr", "%05i.png" % (step,))) utils.save_image(self.var_H[0].cpu(), os.path.join(sample_save_path, "hr", "%05i.png" % (step,)))
utils.save_image(self.var_L[0].cpu(), os.path.join(sample_save_path, "lr", "%05i.png" % (step,))) utils.save_image(self.var_L[0].cpu(), os.path.join(sample_save_path, "lr", "%05i.png" % (step,)))
utils.save_image(self.fake_H[0].cpu(), os.path.join(sample_save_path, "gen", "%05i.png" % (step,))) utils.save_image(self.fake_H[0].cpu(), os.path.join(sample_save_path, "gen", "%05i.png" % (step,)))
utils.save_image(self.grad_LR[0].cpu(), os.path.join(sample_save_path, "gen_grad", "%05i.png" % (step,)))
# set log # set log
if step % self.D_update_ratio == 0 and step > self.D_init_iters: if step % self.D_update_ratio == 0 and step > self.D_init_iters:
# G # G
if self.cri_pix: if self.cri_pix:
-    self.log_dict['l_g_pix'] = l_g_pix.item()
+    self.add_log_entry('l_g_pix', l_g_pix.item())
if self.cri_fea:
-    self.log_dict['l_g_fea'] = l_g_fea.item()
+    self.add_log_entry('l_g_fea', l_g_fea.item())
if self.l_gan_w > 0:
-    self.log_dict['l_g_gan'] = l_g_gan.item()
+    self.add_log_entry('l_g_gan', l_g_gan.item())
if self.cri_pix_branch: #branch pixel loss
-    self.log_dict['l_g_pix_grad_branch'] = l_g_pix_grad_branch.item()
+    self.add_log_entry('l_g_pix_grad_branch', l_g_pix_grad_branch.item())
if self.l_gan_w > 0:
-    # D
-    self.log_dict['l_d_real'] = l_d_real.item()
-    self.log_dict['l_d_fake'] = l_d_fake.item()
-    # D_grad
-    self.log_dict['l_d_real_grad'] = l_d_real_grad.item()
-    self.log_dict['l_d_fake_grad'] = l_d_fake_grad.item()
-    if self.opt['train']['gan_type'] == 'wgan-gp':
-        self.log_dict['l_d_gp'] = l_d_gp.item()
-    # D outputs
-    self.log_dict['D_real'] = torch.mean(pred_d_real.detach())
-    self.log_dict['D_fake'] = torch.mean(pred_d_fake.detach())
-    # D_grad outputs
-    self.log_dict['D_real_grad'] = torch.mean(pred_d_real_grad.detach())
-    self.log_dict['D_fake_grad'] = torch.mean(pred_d_fake_grad.detach())
+    self.add_log_entry('l_d_real', l_d_real.item())
+    self.add_log_entry('l_d_fake', l_d_fake.item())
+    self.add_log_entry('l_d_real_grad', l_d_real_grad.item())
+    self.add_log_entry('l_d_fake_grad', l_d_fake_grad.item())
+    self.add_log_entry('D_real', torch.mean(pred_d_real.detach()))
+    self.add_log_entry('D_fake', torch.mean(pred_d_fake.detach()))
+    self.add_log_entry('D_real_grad', torch.mean(pred_d_real_grad.detach()))
+    self.add_log_entry('D_fake_grad', torch.mean(pred_d_fake_grad.detach()))
+# Allows the log to serve as an easy-to-use rotating buffer.
+def add_log_entry(self, key, value):
+    key_it = "%s_it" % (key,)
+    log_rotating_buffer_size = 50
+    if key not in self.log_dict.keys():
+        self.log_dict[key] = []
+        self.log_dict[key_it] = 0
+    if len(self.log_dict[key]) < log_rotating_buffer_size:
+        self.log_dict[key].append(value)
+    else:
+        self.log_dict[key][self.log_dict[key_it] % log_rotating_buffer_size] = value
+    self.log_dict[key_it] += 1
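The add_log_entry scheme above keeps a rotating window of recent values per key, plus a "<key>_it" counter, so that get_current_log can report an average rather than only the latest value. A standalone sketch of the same bookkeeping (buffer size and key name illustrative):

log_dict = {}
log_rotating_buffer_size = 50

def add_log_entry(key, value):
    key_it = "%s_it" % (key,)
    if key not in log_dict:
        log_dict[key] = []
        log_dict[key_it] = 0
    if len(log_dict[key]) < log_rotating_buffer_size:
        log_dict[key].append(value)
    else:
        # Once the buffer is full, overwrite the oldest slot.
        log_dict[key][log_dict[key_it] % log_rotating_buffer_size] = value
    log_dict[key_it] += 1

for step in range(120):
    add_log_entry('l_g_pix', float(step))
print(len(log_dict['l_g_pix']), log_dict['l_g_pix_it'])   # 50 120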
def test(self): def test(self):
self.netG.eval() self.netG.eval()
@ -413,8 +353,21 @@ class SPSRModel(BaseModel):
self.netG.train() self.netG.train()
+# Fetches a summary of the log.
def get_current_log(self, step):
-    return self.log_dict
+    return_log = {}
+    for k in self.log_dict.keys():
+        if not isinstance(self.log_dict[k], list):
+            continue
+        return_log[k] = sum(self.log_dict[k]) / len(self.log_dict[k])
+    # Some generators can do their own metric logging.
+    if hasattr(self.netG.module, "get_debug_values"):
+        return_log.update(self.netG.module.get_debug_values(step))
+    if hasattr(self.netD.module, "get_debug_values"):
+        return_log.update(self.netD.module.get_debug_values(step))
+    return return_log
def get_current_visuals(self, need_HR=True): def get_current_visuals(self, need_HR=True):
out_dict = OrderedDict() out_dict = OrderedDict()
@ -470,6 +423,10 @@ class SPSRModel(BaseModel):
if self.opt['is_train'] and load_path_D is not None: if self.opt['is_train'] and load_path_D is not None:
logger.info('Loading pretrained model for D [{:s}] ...'.format(load_path_D)) logger.info('Loading pretrained model for D [{:s}] ...'.format(load_path_D))
self.load_network(load_path_D, self.netD) self.load_network(load_path_D, self.netD)
load_path_D_grad = self.opt['path']['pretrain_model_D_grad']
if self.opt['is_train'] and load_path_D_grad is not None:
logger.info('Loading pretrained model for D_grad [{:s}] ...'.format(load_path_D_grad))
self.load_network(load_path_D_grad, self.netD_grad)
def compute_fea_loss(self, real, fake): def compute_fea_loss(self, real, fake):
if self.cri_fea is None: if self.cri_fea is None:

View File

@ -1,366 +0,0 @@
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from . import block as B
class Get_gradient_nopadding(nn.Module):
def __init__(self):
super(Get_gradient_nopadding, self).__init__()
kernel_v = [[0, -1, 0],
[0, 0, 0],
[0, 1, 0]]
kernel_h = [[0, 0, 0],
[-1, 0, 1],
[0, 0, 0]]
kernel_h = torch.FloatTensor(kernel_h).unsqueeze(0).unsqueeze(0)
kernel_v = torch.FloatTensor(kernel_v).unsqueeze(0).unsqueeze(0)
self.weight_h = nn.Parameter(data = kernel_h, requires_grad = False)
self.weight_v = nn.Parameter(data = kernel_v, requires_grad = False)
def forward(self, x):
x_list = []
for i in range(x.shape[1]):
x_i = x[:, i]
x_i_v = F.conv2d(x_i.unsqueeze(1), self.weight_v, padding=1)
x_i_h = F.conv2d(x_i.unsqueeze(1), self.weight_h, padding=1)
x_i = torch.sqrt(torch.pow(x_i_v, 2) + torch.pow(x_i_h, 2) + 1e-6)
x_list.append(x_i)
x = torch.cat(x_list, dim = 1)
return x
####################
# Generator
####################
class SPSRNet(nn.Module):
def __init__(self, in_nc, out_nc, nf, nb, gc=32, upscale=4, norm_type=None, \
act_type='leakyrelu', mode='CNA', upsample_mode='upconv'):
super(SPSRNet, self).__init__()
n_upscale = int(math.log(upscale, 2))
if upscale == 3:
n_upscale = 1
fea_conv = B.conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None)
rb_blocks = [B.RRDB(nf, kernel_size=3, gc=32, stride=1, bias=True, pad_type='zero', \
norm_type=norm_type, act_type=act_type, mode='CNA') for _ in range(nb)]
LR_conv = B.conv_block(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None, mode=mode)
if upsample_mode == 'upconv':
upsample_block = B.upconv_blcok
elif upsample_mode == 'pixelshuffle':
upsample_block = B.pixelshuffle_block
else:
raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode))
if upscale == 3:
upsampler = upsample_block(nf, nf, 3, act_type=act_type)
else:
upsampler = [upsample_block(nf, nf, act_type=act_type) for _ in range(n_upscale)]
self.HR_conv0_new = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type)
self.HR_conv1_new = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=None)
self.model = B.sequential(fea_conv, B.ShortcutBlock(B.sequential(*rb_blocks, LR_conv)),\
*upsampler, self.HR_conv0_new)
self.get_g_nopadding = Get_gradient_nopadding()
self.b_fea_conv = B.conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None)
self.b_concat_1 = B.conv_block(2*nf, nf, kernel_size=3, norm_type=None, act_type = None)
self.b_block_1 = B.RRDB(nf*2, kernel_size=3, gc=32, stride=1, bias=True, pad_type='zero', \
norm_type=norm_type, act_type=act_type, mode='CNA')
self.b_concat_2 = B.conv_block(2*nf, nf, kernel_size=3, norm_type=None, act_type = None)
self.b_block_2 = B.RRDB(nf*2, kernel_size=3, gc=32, stride=1, bias=True, pad_type='zero', \
norm_type=norm_type, act_type=act_type, mode='CNA')
self.b_concat_3 = B.conv_block(2*nf, nf, kernel_size=3, norm_type=None, act_type = None)
self.b_block_3 = B.RRDB(nf*2, kernel_size=3, gc=32, stride=1, bias=True, pad_type='zero', \
norm_type=norm_type, act_type=act_type, mode='CNA')
self.b_concat_4 = B.conv_block(2*nf, nf, kernel_size=3, norm_type=None, act_type = None)
self.b_block_4 = B.RRDB(nf*2, kernel_size=3, gc=32, stride=1, bias=True, pad_type='zero', \
norm_type=norm_type, act_type=act_type, mode='CNA')
self.b_LR_conv = B.conv_block(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None, mode=mode)
if upsample_mode == 'upconv':
upsample_block = B.upconv_blcok
elif upsample_mode == 'pixelshuffle':
upsample_block = B.pixelshuffle_block
else:
raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode))
if upscale == 3:
b_upsampler = upsample_block(nf, nf, 3, act_type=act_type)
else:
b_upsampler = [upsample_block(nf, nf, act_type=act_type) for _ in range(n_upscale)]
b_HR_conv0 = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type)
b_HR_conv1 = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=None)
self.b_module = B.sequential(*b_upsampler, b_HR_conv0, b_HR_conv1)
self.conv_w = B.conv_block(nf, out_nc, kernel_size=1, norm_type=None, act_type=None)
self.f_concat = B.conv_block(nf*2, nf, kernel_size=3, norm_type=None, act_type=None)
self.f_block = B.RRDB(nf*2, kernel_size=3, gc=32, stride=1, bias=True, pad_type='zero', \
norm_type=norm_type, act_type=act_type, mode='CNA')
self.f_HR_conv0 = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type)
self.f_HR_conv1 = B.conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None)
def forward(self, x):
x_grad = self.get_g_nopadding(x)
x = self.model[0](x)
x, block_list = self.model[1](x)
x_ori = x
for i in range(5):
x = block_list[i](x)
x_fea1 = x
for i in range(5):
x = block_list[i+5](x)
x_fea2 = x
for i in range(5):
x = block_list[i+10](x)
x_fea3 = x
for i in range(5):
x = block_list[i+15](x)
x_fea4 = x
x = block_list[20:](x)
#short cut
x = x_ori+x
x= self.model[2:](x)
x = self.HR_conv1_new(x)
x_b_fea = self.b_fea_conv(x_grad)
x_cat_1 = torch.cat([x_b_fea, x_fea1], dim=1)
x_cat_1 = self.b_block_1(x_cat_1)
x_cat_1 = self.b_concat_1(x_cat_1)
x_cat_2 = torch.cat([x_cat_1, x_fea2], dim=1)
x_cat_2 = self.b_block_2(x_cat_2)
x_cat_2 = self.b_concat_2(x_cat_2)
x_cat_3 = torch.cat([x_cat_2, x_fea3], dim=1)
x_cat_3 = self.b_block_3(x_cat_3)
x_cat_3 = self.b_concat_3(x_cat_3)
x_cat_4 = torch.cat([x_cat_3, x_fea4], dim=1)
x_cat_4 = self.b_block_4(x_cat_4)
x_cat_4 = self.b_concat_4(x_cat_4)
x_cat_4 = self.b_LR_conv(x_cat_4)
#short cut
x_cat_4 = x_cat_4+x_b_fea
x_branch = self.b_module(x_cat_4)
x_out_branch = self.conv_w(x_branch)
########
x_branch_d = x_branch
x_f_cat = torch.cat([x_branch_d, x], dim=1)
x_f_cat = self.f_block(x_f_cat)
x_out = self.f_concat(x_f_cat)
x_out = self.f_HR_conv0(x_out)
x_out = self.f_HR_conv1(x_out)
#########
return x_out_branch, x_out, x_grad
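The forward pass above runs the trunk's RRDB blocks one at a time so it can tap intermediate features after every fifth block (x_fea1 .. x_fea4) and hand them to the gradient branch. A toy sketch of that tapping pattern, using identity blocks as stand-ins for the RRDBs (23 blocks is just the usual setting, not a requirement):

import torch
import torch.nn as nn

blocks = nn.ModuleList([nn.Identity() for _ in range(23)])   # stand-ins for the RRDB trunk
x = torch.rand(1, 64, 16, 16)
taps = []
for i, blk in enumerate(blocks, start=1):
    x = blk(x)
    if i % 5 == 0 and len(taps) < 4:
        taps.append(x)        # roughly x_fea1 .. x_fea4 in the network above
print(len(taps))              # 4 feature maps fed to the gradient branch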
####################
# Discriminator
####################
# VGG style Discriminator with input size 128*128
class Discriminator_VGG_128(nn.Module):
def __init__(self, in_nc, base_nf, norm_type='batch', act_type='leakyrelu', mode='CNA'):
super(Discriminator_VGG_128, self).__init__()
# features
# hxw, c
# 128, 64
conv0 = B.conv_block(in_nc, base_nf, kernel_size=3, norm_type=None, act_type=act_type, \
mode=mode)
conv1 = B.conv_block(base_nf, base_nf, kernel_size=4, stride=2, norm_type=norm_type, \
act_type=act_type, mode=mode)
# 64, 64
conv2 = B.conv_block(base_nf, base_nf*2, kernel_size=3, stride=1, norm_type=norm_type, \
act_type=act_type, mode=mode)
conv3 = B.conv_block(base_nf*2, base_nf*2, kernel_size=4, stride=2, norm_type=norm_type, \
act_type=act_type, mode=mode)
# 32, 128
conv4 = B.conv_block(base_nf*2, base_nf*4, kernel_size=3, stride=1, norm_type=norm_type, \
act_type=act_type, mode=mode)
conv5 = B.conv_block(base_nf*4, base_nf*4, kernel_size=4, stride=2, norm_type=norm_type, \
act_type=act_type, mode=mode)
# 16, 256
conv6 = B.conv_block(base_nf*4, base_nf*8, kernel_size=3, stride=1, norm_type=norm_type, \
act_type=act_type, mode=mode)
conv7 = B.conv_block(base_nf*8, base_nf*8, kernel_size=4, stride=2, norm_type=norm_type, \
act_type=act_type, mode=mode)
# 8, 512
conv8 = B.conv_block(base_nf*8, base_nf*8, kernel_size=3, stride=1, norm_type=norm_type, \
act_type=act_type, mode=mode)
conv9 = B.conv_block(base_nf*8, base_nf*8, kernel_size=4, stride=2, norm_type=norm_type, \
act_type=act_type, mode=mode)
# 4, 512
self.features = B.sequential(conv0, conv1, conv2, conv3, conv4, conv5, conv6, conv7, conv8,\
conv9)
# classifier
self.classifier = nn.Sequential(
nn.Linear(512 * 4 * 4, 100), nn.LeakyReLU(0.2, True), nn.Linear(100, 1))
def forward(self, x):
x = self.features(x)
x = x.view(x.size(0), -1)
x = self.classifier(x)
return x
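A quick check of the classifier's input size above, assuming the default base_nf=64 (so the deepest blocks carry base_nf*8 = 512 channels): five stride-2 convolutions reduce a 128x128 input to 4x4, which is where nn.Linear(512 * 4 * 4, 100) comes from.

side = 128
for _ in range(5):                 # conv1, conv3, conv5, conv7, conv9 each have stride 2
    side //= 2
print(side, 512 * side * side)     # 4 8192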
####################
# Perceptual Network
####################
class VGGFeatureExtractor(nn.Module):
def __init__(self,
feature_layer=34,
use_bn=False,
use_input_norm=True,
device=torch.device('cpu')):
super(VGGFeatureExtractor, self).__init__()
if use_bn:
model = torchvision.models.vgg19_bn(pretrained=True)
else:
model = torchvision.models.vgg19(pretrained=True)
self.use_input_norm = use_input_norm
if self.use_input_norm:
mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device)
# [0.485-1, 0.456-1, 0.406-1] if input in range [-1,1]
std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device)
# [0.229*2, 0.224*2, 0.225*2] if input in range [-1,1]
self.register_buffer('mean', mean)
self.register_buffer('std', std)
self.features = nn.Sequential(*list(model.features.children())[:(feature_layer + 1)])
# No need to BP to variable
for k, v in self.features.named_parameters():
v.requires_grad = False
def forward(self, x):
if self.use_input_norm:
x = (x - self.mean) / self.std
output = self.features(x)
return output
class ResNet101FeatureExtractor(nn.Module):
def __init__(self, use_input_norm=True, device=torch.device('cpu')):
super(ResNet101FeatureExtractor, self).__init__()
model = torchvision.models.resnet101(pretrained=True)
self.use_input_norm = use_input_norm
if self.use_input_norm:
mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device)
# [0.485-1, 0.456-1, 0.406-1] if input in range [-1,1]
std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device)
# [0.229*2, 0.224*2, 0.225*2] if input in range [-1,1]
self.register_buffer('mean', mean)
self.register_buffer('std', std)
self.features = nn.Sequential(*list(model.children())[:8])
# No need to BP to variable
for k, v in self.features.named_parameters():
v.requires_grad = False
def forward(self, x):
if self.use_input_norm:
x = (x - self.mean) / self.std
output = self.features(x)
return output
class MINCNet(nn.Module):
def __init__(self):
super(MINCNet, self).__init__()
self.ReLU = nn.ReLU(True)
self.conv11 = nn.Conv2d(3, 64, 3, 1, 1)
self.conv12 = nn.Conv2d(64, 64, 3, 1, 1)
self.maxpool1 = nn.MaxPool2d(2, stride=2, padding=0, ceil_mode=True)
self.conv21 = nn.Conv2d(64, 128, 3, 1, 1)
self.conv22 = nn.Conv2d(128, 128, 3, 1, 1)
self.maxpool2 = nn.MaxPool2d(2, stride=2, padding=0, ceil_mode=True)
self.conv31 = nn.Conv2d(128, 256, 3, 1, 1)
self.conv32 = nn.Conv2d(256, 256, 3, 1, 1)
self.conv33 = nn.Conv2d(256, 256, 3, 1, 1)
self.maxpool3 = nn.MaxPool2d(2, stride=2, padding=0, ceil_mode=True)
self.conv41 = nn.Conv2d(256, 512, 3, 1, 1)
self.conv42 = nn.Conv2d(512, 512, 3, 1, 1)
self.conv43 = nn.Conv2d(512, 512, 3, 1, 1)
self.maxpool4 = nn.MaxPool2d(2, stride=2, padding=0, ceil_mode=True)
self.conv51 = nn.Conv2d(512, 512, 3, 1, 1)
self.conv52 = nn.Conv2d(512, 512, 3, 1, 1)
self.conv53 = nn.Conv2d(512, 512, 3, 1, 1)
def forward(self, x):
out = self.ReLU(self.conv11(x))
out = self.ReLU(self.conv12(out))
out = self.maxpool1(out)
out = self.ReLU(self.conv21(out))
out = self.ReLU(self.conv22(out))
out = self.maxpool2(out)
out = self.ReLU(self.conv31(out))
out = self.ReLU(self.conv32(out))
out = self.ReLU(self.conv33(out))
out = self.maxpool3(out)
out = self.ReLU(self.conv41(out))
out = self.ReLU(self.conv42(out))
out = self.ReLU(self.conv43(out))
out = self.maxpool4(out)
out = self.ReLU(self.conv51(out))
out = self.ReLU(self.conv52(out))
out = self.conv53(out)
return out
class MINCFeatureExtractor(nn.Module):
def __init__(self, feature_layer=34, use_bn=False, use_input_norm=True, \
device=torch.device('cpu')):
super(MINCFeatureExtractor, self).__init__()
self.features = MINCNet()
self.features.load_state_dict(
torch.load('../experiments/pretrained_models/VGG16minc_53.pth'), strict=True)
self.features.eval()
# No need to BP to variable
for k, v in self.features.named_parameters():
v.requires_grad = False
def forward(self, x):
output = self.features(x)
return output

View File

@ -1,29 +0,0 @@
import torch
import torch.nn as nn
# Define GAN loss: [vanilla | lsgan]
class GANLoss(nn.Module):
def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0):
super(GANLoss, self).__init__()
self.gan_type = gan_type.lower()
self.real_label_val = real_label_val
self.fake_label_val = fake_label_val
if self.gan_type == 'vanilla':
self.loss = nn.BCEWithLogitsLoss()
elif self.gan_type == 'lsgan':
self.loss = nn.MSELoss()
else:
raise NotImplementedError('GAN type [{:s}] is not found'.format(self.gan_type))
def get_target_label(self, input, target_is_real):
if target_is_real:
return torch.empty_like(input).fill_(self.real_label_val)
else:
return torch.empty_like(input).fill_(self.fake_label_val)
def forward(self, input, target_is_real):
target_label = self.get_target_label(input, target_is_real)
loss = self.loss(input, target_label)
return loss
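A short usage sketch for the module above (this commit swaps it for the shared models/loss.py version); it assumes the GANLoss class defined above is in scope:

import torch

criterion = GANLoss('lsgan')             # least-squares variant: MSE against 0/1 targets
d_logits_fake = torch.randn(4, 1)        # discriminator outputs for generated images
g_adv = criterion(d_logits_fake, True)   # generator wants fakes scored as real (1.0)
d_adv = criterion(d_logits_fake, False)  # discriminator wants them scored as fake (0.0)
print(float(g_adv), float(d_adv))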

View File

@ -1,94 +0,0 @@
import random
import torch
import numpy as np
def _get_random_crop_indices(crop_region, crop_size):
'''
crop_region: (strat_y, end_y, start_x, end_x)
crop_size: (y, x)
'''
region_size = (crop_region[1] - crop_region[0], crop_region[3] - crop_region[2])
if region_size[0] < crop_size[0] or region_size[1] < crop_size[1]:
print(region_size, crop_size)
assert region_size[0] >= crop_size[0] and region_size[1] >= crop_size[1]
if region_size[0] == crop_size[0]:
start_y = crop_region[0]
else:
start_y = random.choice(range(crop_region[0], crop_region[1] - crop_size[0]))
if region_size[1] == crop_size[1]:
start_x = crop_region[2]
else:
start_x = random.choice(range(crop_region[2], crop_region[3] - crop_size[1]))
return start_y, start_y + crop_size[0], start_x, start_x + crop_size[1]
def _get_adaptive_crop_indices(crop_region, crop_size, num_candidate, dist_map, min_diff=False):
candidates = [_get_random_crop_indices(crop_region, crop_size) for _ in range(num_candidate)]
max_choice = candidates[0]
min_choice = candidates[0]
max_dist = 0
min_dist = np.infty
with torch.no_grad():
for c in candidates:
start_y, end_y, start_x, end_x = c
dist = torch.sum(dist_map[start_y: end_y, start_x: end_x])
if dist > max_dist:
max_dist = dist
max_choice = c
if dist < min_dist:
min_dist = dist
min_choice = c
if min_diff:
return min_choice
else:
return max_choice
def get_split_list(divisor, dividend):
split_list = [dividend // divisor for _ in range(divisor - 1)]
split_list.append(dividend - (dividend // divisor) * (divisor - 1))
return split_list
def random_sampler(pic_size, crop_dict):
crop_region = (0, pic_size[0], 0, pic_size[1])
crop_res_dict = {}
for k, v in crop_dict.items():
crop_size = (int(k), int(k))
crop_res_dict[k] = [_get_random_crop_indices(crop_region, crop_size) for _ in range(v)]
return crop_res_dict
def region_sampler(crop_region, crop_dict):
crop_res_dict = {}
for k, v in crop_dict.items():
crop_size = (int(k), int(k))
crop_res_dict[k] = [_get_random_crop_indices(crop_region, crop_size) for _ in range(v)]
return crop_res_dict
def adaptive_sampler(pic_size, crop_dict, num_candidate_dict, dist_map, min_diff=False):
crop_region = (0, pic_size[0], 0, pic_size[1])
crop_res_dict = {}
for k, v in crop_dict.items():
crop_size = (int(k), int(k))
crop_res_dict[k] = [_get_adaptive_crop_indices(crop_region, crop_size, num_candidate_dict[k], dist_map, min_diff) for _ in range(v)]
return crop_res_dict
# TODO more flexible
def pyramid_sampler(pic_size, crop_dict):
crop_res_dict = {}
sorted_key = list(crop_dict.keys())
sorted_key.sort(key=lambda x: int(x), reverse=True)
k = sorted_key[0]
crop_size = (int(k), int(k))
crop_region = (0, pic_size[0], 0, pic_size[1])
crop_res_dict[k] = [_get_random_crop_indices(crop_region, crop_size) for _ in range(crop_dict[k])]
for i in range(1, len(sorted_key)):
crop_res_dict[sorted_key[i]] = []
afore_num = crop_dict[sorted_key[i-1]]
new_num = crop_dict[sorted_key[i]]
split_list = get_split_list(afore_num, new_num)
crop_size = (int(sorted_key[i]), int(sorted_key[i]))
for j in range(len(split_list)):
crop_region = crop_res_dict[sorted_key[i-1]][j]
crop_res_dict[sorted_key[i]].extend([_get_random_crop_indices(crop_region, crop_size) for _ in range(split_list[j])])
return crop_res_dict
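For reference, a usage sketch for the samplers above (all deleted in this commit), assuming the functions are in scope; crop_dict maps a crop size (as a string) to how many windows of that size to draw:

crops = random_sampler((256, 256), {'64': 3, '32': 2})
print(len(crops['64']), crops['64'][0])   # 3 windows, each as (start_y, end_y, start_x, end_x)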

View File

@ -1,161 +0,0 @@
import functools
import logging
import torch
import torch.nn as nn
from torch.nn import init
import models.SPSR_modules.architecture as arch
logger = logging.getLogger('base')
####################
# initialize
####################
def weights_init_normal(m, std=0.02):
classname = m.__class__.__name__
if classname.find('Conv') != -1:
init.normal_(m.weight.data, 0.0, std)
if m.bias is not None:
m.bias.data.zero_()
elif classname.find('Linear') != -1:
init.normal_(m.weight.data, 0.0, std)
if m.bias is not None:
m.bias.data.zero_()
elif classname.find('BatchNorm2d') != -1:
init.normal_(m.weight.data, 1.0, std) # BN also uses norm
init.constant_(m.bias.data, 0.0)
def weights_init_kaiming(m, scale=1):
classname = m.__class__.__name__
if classname.find('Conv') != -1:
init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
m.weight.data *= scale
if m.bias is not None:
m.bias.data.zero_()
elif classname.find('Linear') != -1:
init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
m.weight.data *= scale
if m.bias is not None:
m.bias.data.zero_()
elif classname.find('BatchNorm2d') != -1:
if m.affine != False:
init.constant_(m.weight.data, 1.0)
init.constant_(m.bias.data, 0.0)
def weights_init_orthogonal(m):
classname = m.__class__.__name__
if classname.find('Conv') != -1:
init.orthogonal_(m.weight.data, gain=1)
if m.bias is not None:
m.bias.data.zero_()
elif classname.find('Linear') != -1:
init.orthogonal_(m.weight.data, gain=1)
if m.bias is not None:
m.bias.data.zero_()
elif classname.find('BatchNorm2d') != -1:
init.constant_(m.weight.data, 1.0)
init.constant_(m.bias.data, 0.0)
def init_weights(net, init_type='kaiming', scale=1, std=0.02):
# scale for 'kaiming', std for 'normal'.
if init_type == 'normal':
weights_init_normal_ = functools.partial(weights_init_normal, std=std)
net.apply(weights_init_normal_)
elif init_type == 'kaiming':
weights_init_kaiming_ = functools.partial(weights_init_kaiming, scale=scale)
net.apply(weights_init_kaiming_)
elif init_type == 'orthogonal':
net.apply(weights_init_orthogonal)
else:
raise NotImplementedError('initialization method [{:s}] not implemented'.format(init_type))
####################
# define network
####################
# Generator
def define_G(opt, device=None):
opt_net = opt['network_G']
which_model = opt_net['which_model_G']
if which_model == 'spsr_net':
netG = arch.SPSRNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'],
nb=opt_net['nb'], gc=opt_net['gc'], upscale=opt_net['scale'], norm_type=opt_net['norm_type'],
act_type='leakyrelu', mode=opt_net['mode'], upsample_mode='upconv')
else:
raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model))
if opt['is_train']:
init_weights(netG, init_type='kaiming', scale=0.1)
return netG
# Discriminator
def define_D(opt):
opt_net = opt['network_D']
which_model = opt_net['which_model_D']
if which_model == 'discriminator_vgg_128':
netD = arch.Discriminator_VGG_128(in_nc=opt_net['in_nc'], base_nf=opt_net['nf'], \
norm_type=opt_net['norm_type'], mode=opt_net['mode'], act_type=opt_net['act_type'])
elif which_model == 'discriminator_vgg_96':
netD = arch.Discriminator_VGG_96(in_nc=opt_net['in_nc'], base_nf=opt_net['nf'], \
norm_type=opt_net['norm_type'], mode=opt_net['mode'], act_type=opt_net['act_type'])
elif which_model == 'discriminator_vgg_192':
netD = arch.Discriminator_VGG_192(in_nc=opt_net['in_nc'], base_nf=opt_net['nf'], \
norm_type=opt_net['norm_type'], mode=opt_net['mode'], act_type=opt_net['act_type'])
elif which_model == 'discriminator_vgg_128_SN':
netD = arch.Discriminator_VGG_128_SN()
else:
raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model))
init_weights(netD, init_type='kaiming', scale=1)
return netD
def define_D_grad(opt):
opt_net = opt['network_D']
which_model = opt_net['which_model_D']
if which_model == 'discriminator_vgg_128':
netD = arch.Discriminator_VGG_128(in_nc=opt_net['in_nc'], base_nf=opt_net['nf'], \
norm_type=opt_net['norm_type'], mode=opt_net['mode'], act_type=opt_net['act_type'])
elif which_model == 'discriminator_vgg_96':
netD = arch.Discriminator_VGG_96(in_nc=opt_net['in_nc'], base_nf=opt_net['nf'], \
norm_type=opt_net['norm_type'], mode=opt_net['mode'], act_type=opt_net['act_type'])
elif which_model == 'discriminator_vgg_192':
netD = arch.Discriminator_VGG_192(in_nc=opt_net['in_nc'], base_nf=opt_net['nf'], \
norm_type=opt_net['norm_type'], mode=opt_net['mode'], act_type=opt_net['act_type'])
elif which_model == 'discriminator_vgg_128_SN':
netD = arch.Discriminator_VGG_128_SN()
else:
raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model))
init_weights(netD, init_type='kaiming', scale=1)
return netD
def define_F(opt, use_bn=False):
device = torch.device('cuda')
# pytorch pretrained VGG19-54, before ReLU.
if use_bn:
feature_layer = 49
else:
feature_layer = 34
netF = arch.VGGFeatureExtractor(feature_layer=feature_layer, use_bn=use_bn, \
use_input_norm=True, device=device)
netF.eval()
return netF
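A hedged sketch of the option dictionary these factory functions read; the key names come straight from the code above, while the values are only illustrative:

opt = {
    'is_train': True,
    'network_G': {'which_model_G': 'spsr_net', 'in_nc': 3, 'out_nc': 3, 'nf': 64,
                  'nb': 23, 'gc': 32, 'scale': 4, 'norm_type': None, 'mode': 'CNA'},
    'network_D': {'which_model_D': 'discriminator_vgg_128', 'in_nc': 3, 'nf': 64,
                  'norm_type': 'batch', 'act_type': 'leakyrelu', 'mode': 'CNA'},
}
# netG = define_G(opt); netD = define_D(opt); netF = define_F(opt)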

View File

@ -9,6 +9,7 @@ from models.base_model import BaseModel
from models.loss import GANLoss, FDPLLoss from models.loss import GANLoss, FDPLLoss
from apex import amp from apex import amp
from data.weight_scheduler import get_scheduler_for_opt from data.weight_scheduler import get_scheduler_for_opt
from .archs.SPSR_arch import ImageGradient, ImageGradientNoPadding
import torch.nn.functional as F import torch.nn.functional as F
import glob import glob
import random import random
@ -27,11 +28,18 @@ class SRGANModel(BaseModel):
else: else:
self.rank = -1 # non dist training self.rank = -1 # non dist training
train_opt = opt['train'] train_opt = opt['train']
self.spsr_enabled = 'spsr' in opt['model']
# Only pixgan and gan are currently supported in spsr_mode
if self.spsr_enabled:
assert train_opt['gan_type'] == 'pixgan' or train_opt['gan_type'] == 'gan'
# define networks and load pretrained models # define networks and load pretrained models
self.netG = networks.define_G(opt).to(self.device) self.netG = networks.define_G(opt).to(self.device)
if self.is_train: if self.is_train:
self.netD = networks.define_D(opt).to(self.device) self.netD = networks.define_D(opt).to(self.device)
if self.spsr_enabled:
self.netD_grad = networks.define_D(opt).to(self.device) # D_grad
if 'network_C' in opt.keys(): if 'network_C' in opt.keys():
self.netC = networks.define_G(opt, net_key='network_C').to(self.device) self.netC = networks.define_G(opt, net_key='network_C').to(self.device)
@ -73,6 +81,33 @@ class SRGANModel(BaseModel):
else: else:
self.fdpl_enabled = False self.fdpl_enabled = False
if self.spsr_enabled:
spsr_opt = train_opt['spsr']
self.branch_pretrain = spsr_opt['branch_pretrain'] if spsr_opt['branch_pretrain'] else 0
self.branch_init_iters = spsr_opt['branch_init_iters'] if spsr_opt['branch_init_iters'] else 1
if spsr_opt['gradient_pixel_weight'] > 0:
self.cri_pix_grad = nn.MSELoss().to(self.device)
self.l_pix_grad_w = spsr_opt['gradient_pixel_weight']
else:
self.cri_pix_grad = None
if spsr_opt['gradient_gan_weight'] > 0:
self.cri_grad_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device)
self.l_gan_grad_w = spsr_opt['gradient_gan_weight']
else:
self.cri_grad_gan = None
if spsr_opt['pixel_branch_weight'] > 0:
l_pix_type = spsr_opt['pixel_branch_criterion']
if l_pix_type == 'l1':
self.cri_pix_branch = nn.L1Loss().to(self.device)
elif l_pix_type == 'l2':
self.cri_pix_branch = nn.MSELoss().to(self.device)
else:
raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_pix_type))
self.l_pix_branch_w = spsr_opt['pixel_branch_weight']
else:
logger.info('Remove G_grad pixel loss.')
self.cri_pix_branch = None
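This block expects its knobs under a new train_opt['spsr'] sub-dictionary. A hedged sketch of the keys it reads (values are illustrative, not recommendations):

train_opt = {
    'gan_type': 'gan',                  # spsr mode asserts 'gan' or 'pixgan' above
    'spsr': {
        'branch_pretrain': 1,           # pretrain only the '_branch_pretrain' params first...
        'branch_init_iters': 5000,      # ...for this many steps
        'gradient_pixel_weight': 0.01,  # MSE between gradients of the SR output and the HR target
        'gradient_gan_weight': 0.005,   # adversarial weight for the gradient discriminator
        'pixel_branch_weight': 0.5,     # pixel loss on the gradient-branch output
        'pixel_branch_criterion': 'l1', # 'l1' or 'l2'
    },
}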
# G feature loss # G feature loss
if train_opt['feature_weight'] and train_opt['feature_weight'] > 0: if train_opt['feature_weight'] and train_opt['feature_weight'] > 0:
# For backwards compatibility, use a scheduler definition instead. Remove this at some point. # For backwards compatibility, use a scheduler definition instead. Remove this at some point.
@ -139,7 +174,7 @@ class SRGANModel(BaseModel):
self.corruptor_usage_prob = train_opt['corruptor_usage_probability'] if train_opt['corruptor_usage_probability'] else .5 self.corruptor_usage_prob = train_opt['corruptor_usage_probability'] if train_opt['corruptor_usage_probability'] else .5
# optimizers # optimizers
-        # G
+        # G optimizer
wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0 wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
optim_params = [] optim_params = []
if train_opt['lr_scheme'] == 'ProgressiveMultiStepLR': if train_opt['lr_scheme'] == 'ProgressiveMultiStepLR':
@ -155,6 +190,7 @@ class SRGANModel(BaseModel):
weight_decay=wd_G, weight_decay=wd_G,
betas=(train_opt['beta1_G'], train_opt['beta2_G'])) betas=(train_opt['beta1_G'], train_opt['beta2_G']))
self.optimizers.append(self.optimizer_G) self.optimizers.append(self.optimizer_G)
# D optimizer
optim_params = [] optim_params = []
for k, v in self.netD.named_parameters(): # can optimize for a part of the model for k, v in self.netD.named_parameters(): # can optimize for a part of the model
if v.requires_grad: if v.requires_grad:
@ -162,16 +198,40 @@ class SRGANModel(BaseModel):
else: else:
if self.rank <= 0: if self.rank <= 0:
logger.warning('Params [{:s}] will not optimize.'.format(k)) logger.warning('Params [{:s}] will not optimize.'.format(k))
# D
wd_D = train_opt['weight_decay_D'] if train_opt['weight_decay_D'] else 0 wd_D = train_opt['weight_decay_D'] if train_opt['weight_decay_D'] else 0
self.optimizer_D = torch.optim.Adam(optim_params, lr=train_opt['lr_D'], self.optimizer_D = torch.optim.Adam(optim_params, lr=train_opt['lr_D'],
weight_decay=wd_D, weight_decay=wd_D,
betas=(train_opt['beta1_D'], train_opt['beta2_D'])) betas=(train_opt['beta1_D'], train_opt['beta2_D']))
self.optimizers.append(self.optimizer_D) self.optimizers.append(self.optimizer_D)
-        # AMP
-        [self.netG, self.netD], [self.optimizer_G, self.optimizer_D] = \
-            amp.initialize([self.netG, self.netD], [self.optimizer_G, self.optimizer_D], opt_level=self.amp_level, num_losses=3)
+        if self.spsr_enabled:
+            # D_grad optimizer
+            optim_params = []
for k, v in self.netD_grad.named_parameters(): # can optimize for a part of the model
if v.requires_grad:
optim_params.append(v)
else:
if self.rank <= 0:
logger.warning('Params [{:s}] will not optimize.'.format(k))
# D
wd_D = train_opt['weight_decay_D'] if train_opt['weight_decay_D'] else 0
self.optimizer_D_grad = torch.optim.Adam(optim_params, lr=train_opt['lr_D'],
weight_decay=wd_D,
betas=(train_opt['beta1_D'], train_opt['beta2_D']))
self.optimizers.append(self.optimizer_D_grad)
if self.spsr_enabled:
self.get_grad = ImageGradient().to(self.device)
self.get_grad_nopadding = ImageGradientNoPadding().to(self.device)
[self.netG, self.netD, self.netD_grad, self.get_grad, self.get_grad_nopadding], \
[self.optimizer_G, self.optimizer_D, self.optimizer_D_grad] = \
amp.initialize([self.netG, self.netD, self.netD_grad, self.get_grad, self.get_grad_nopadding],
[self.optimizer_G, self.optimizer_D, self.optimizer_D_grad],
opt_level=self.amp_level, num_losses=3)
else:
# AMP
[self.netG, self.netD], [self.optimizer_G, self.optimizer_D] = \
amp.initialize([self.netG, self.netD], [self.optimizer_G, self.optimizer_D], opt_level=self.amp_level, num_losses=3)
# DataParallel # DataParallel
if opt['dist']: if opt['dist']:
@ -188,6 +248,8 @@ class SRGANModel(BaseModel):
self.netD = DataParallel(self.netD) self.netD = DataParallel(self.netD)
self.netG.train() self.netG.train()
self.netD.train() self.netD.train()
if self.spsr_enabled:
self.netD_grad.train()
# schedulers # schedulers
if train_opt['lr_scheme'] == 'MultiStepLR': if train_opt['lr_scheme'] == 'MultiStepLR':
@ -208,6 +270,10 @@ class SRGANModel(BaseModel):
self.schedulers.append(lr_scheduler.ProgressiveMultiStepLR(self.optimizer_D, train_opt['disc_lr_steps'], self.schedulers.append(lr_scheduler.ProgressiveMultiStepLR(self.optimizer_D, train_opt['disc_lr_steps'],
[0], [0],
train_opt['lr_gamma'])) train_opt['lr_gamma']))
if self.spsr_enabled:
self.schedulers.append(lr_scheduler.ProgressiveMultiStepLR(self.optimizer_D_grad, train_opt['disc_lr_steps'],
[0],
train_opt['lr_gamma']))
elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart': elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
for optimizer in self.optimizers: for optimizer in self.optimizers:
self.schedulers.append( self.schedulers.append(
@ -284,18 +350,22 @@ class SRGANModel(BaseModel):
# G # G
for p in self.netD.parameters(): for p in self.netD.parameters():
p.requires_grad = False p.requires_grad = False
-if step >= self.D_init_iters:
-    self.optimizer_G.zero_grad()
+if self.spsr_enabled:
+    for p in self.netD_grad.parameters():
+        p.requires_grad = False

self.swapout_D(step)
self.swapout_G(step)

# Turning off G-grad is required to enable mega-batching and D_update_ratio to work together for some reason.
if step % self.D_update_ratio == 0 and step >= self.D_init_iters:
-    for p in self.netG.parameters():
-        if p.dtype != torch.int64 and p.dtype != torch.bool:
-            p.requires_grad = True
+    if self.spsr_enabled and self.branch_pretrain and step < self.branch_init_iters:
+        for k, v in self.netG.named_parameters():
+            v.requires_grad = '_branch_pretrain' in k
+    else:
+        for p in self.netG.parameters():
+            if p.dtype != torch.int64 and p.dtype != torch.bool:
+                p.requires_grad = True
else: else:
for p in self.netG.parameters(): for p in self.netG.parameters():
p.requires_grad = False p.requires_grad = False
@ -310,17 +380,32 @@ class SRGANModel(BaseModel):
print("Misc setup %f" % (time() - _t,)) print("Misc setup %f" % (time() - _t,))
_t = time() _t = time()
if step >= self.D_init_iters:
self.optimizer_G.zero_grad()
self.fake_GenOut = [] self.fake_GenOut = []
self.fea_GenOut = [] self.fea_GenOut = []
self.fake_H = [] self.fake_H = []
self.spsr_grad_GenOut = []
var_ref_skips = [] var_ref_skips = []
for var_L, var_LGAN, var_H, var_ref, pix in zip(self.var_L, self.gan_img, self.var_H, self.var_ref, self.pix): for var_L, var_LGAN, var_H, var_ref, pix in zip(self.var_L, self.gan_img, self.var_H, self.var_ref, self.pix):
-if random.random() > self.gan_lq_img_use_prob:
-    fea_GenOut, fake_GenOut = self.netG(var_L)
-    using_gan_img = False
-else:
-    fea_GenOut, fake_GenOut = self.netG(var_LGAN)
-    using_gan_img = True
+if self.spsr_enabled:
+    # SPSR models have outputs from three different branches.
+    fake_H_branch, fake_GenOut, grad_LR = self.netG(var_L)
+    fea_GenOut = fake_GenOut
+    using_gan_img = False
+    # Get image gradients for later use.
+    fake_H_grad = self.get_grad(fake_GenOut)
+    var_H_grad = self.get_grad(var_H)
+    var_ref_grad = self.get_grad(var_ref)
+    var_H_grad_nopadding = self.get_grad_nopadding(var_H)
+    self.spsr_grad_GenOut.append(grad_LR)
+else:
+    if random.random() > self.gan_lq_img_use_prob:
+        fea_GenOut, fake_GenOut = self.netG(var_L)
+        using_gan_img = False
+    else:
+        fea_GenOut, fake_GenOut = self.netG(var_LGAN)
+        using_gan_img = True
if _profile: if _profile:
print("Gen forward %f" % (time() - _t,)) print("Gen forward %f" % (time() - _t,))
@ -339,6 +424,13 @@ class SRGANModel(BaseModel):
l_g_pix = self.l_pix_w * self.cri_pix(fea_GenOut, pix) l_g_pix = self.l_pix_w * self.cri_pix(fea_GenOut, pix)
l_g_pix_log = l_g_pix / self.l_pix_w l_g_pix_log = l_g_pix / self.l_pix_w
l_g_total += l_g_pix l_g_total += l_g_pix
if self.spsr_enabled and self.cri_pix_grad: # gradient pixel loss
l_g_pix_grad = self.l_pix_grad_w * self.cri_pix_grad(fake_H_grad, var_H_grad)
l_g_total += l_g_pix_grad
if self.spsr_enabled and self.cri_pix_branch: # branch pixel loss
l_g_pix_grad_branch = self.l_pix_branch_w * self.cri_pix_branch(fake_H_branch,
var_H_grad_nopadding)
l_g_total += l_g_pix_grad_branch
if self.fdpl_enabled and not using_gan_img: if self.fdpl_enabled and not using_gan_img:
l_g_fdpl = self.cri_fdpl(fea_GenOut, pix) l_g_fdpl = self.cri_fdpl(fea_GenOut, pix)
l_g_total += l_g_fdpl * self.fdpl_weight l_g_total += l_g_fdpl * self.fdpl_weight
@ -370,6 +462,7 @@ class SRGANModel(BaseModel):
l_g_fix_disc = l_g_fix_disc + weight * self.cri_fea(fake_fea, real_fea) l_g_fix_disc = l_g_fix_disc + weight * self.cri_fea(fake_fea, real_fea)
l_g_total += l_g_fix_disc l_g_total += l_g_fix_disc
if self.l_gan_w > 0: if self.l_gan_w > 0:
if self.opt['train']['gan_type'] == 'gan' or 'pixgan' in self.opt['train']['gan_type']: if self.opt['train']['gan_type'] == 'gan' or 'pixgan' in self.opt['train']['gan_type']:
pred_g_fake = self.netD(fake_GenOut) pred_g_fake = self.netD(fake_GenOut)
@ -383,6 +476,14 @@ class SRGANModel(BaseModel):
l_g_gan_log = l_g_gan / self.l_gan_w l_g_gan_log = l_g_gan / self.l_gan_w
l_g_total += l_g_gan l_g_total += l_g_gan
if self.spsr_enabled and self.cri_grad_gan: # grad G gan + cls loss
pred_g_fake_grad = self.netD_grad(fake_H_grad)
pred_d_real_grad = self.netD_grad(var_ref_grad).detach()
l_g_gan_grad = self.l_gan_grad_w * (self.cri_grad_gan(pred_d_real_grad - torch.mean(pred_g_fake_grad), False) +
self.cri_grad_gan(pred_g_fake_grad - torch.mean(pred_d_real_grad), True)) /2
l_g_total += l_g_gan_grad
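The generator-side gradient GAN term above is the relativistic-average form: D_grad's score on real gradients is pushed below the mean fake score, and its score on fake gradients above the mean real score. A small numeric sketch, assuming the criterion behind GANLoss('gan') is BCE-with-logits, as in the vanilla GANLoss shown elsewhere in this commit:

import torch
import torch.nn as nn

bce = nn.BCEWithLogitsLoss()
pred_d_real_grad = torch.tensor([1.2, 0.8, 1.0])    # D_grad logits on real image gradients
pred_g_fake_grad = torch.tensor([-0.5, 0.1, 0.2])   # D_grad logits on generated gradients

l_g_gan_grad = (bce(pred_d_real_grad - pred_g_fake_grad.mean(), torch.zeros(3)) +
                bce(pred_g_fake_grad - pred_d_real_grad.mean(), torch.ones(3))) / 2
print(float(l_g_gan_grad))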
# Scale the loss down by the batch factor. # Scale the loss down by the batch factor.
l_g_total_log = l_g_total l_g_total_log = l_g_total
l_g_total = l_g_total / self.mega_batch_factor l_g_total = l_g_total / self.mega_batch_factor
@ -418,8 +519,10 @@ class SRGANModel(BaseModel):
gen_input = var_LGAN gen_input = var_LGAN
# Re-compute generator outputs (post-update). # Re-compute generator outputs (post-update).
with torch.no_grad(): with torch.no_grad():
-    _, fake_H = self.netG(gen_input)
-    # The following line detaches all generator outputs that are not None.
+    if self.spsr_enabled:
+        _, fake_H, _ = self.netG(gen_input)
+    else:
+        _, fake_H = self.netG(gen_input)
fake_H = fake_H.detach() fake_H = fake_H.detach()
if _profile: if _profile:
@ -546,11 +649,36 @@ class SRGANModel(BaseModel):
self.fake_H.append(fake_H.detach()) self.fake_H.append(fake_H.detach())
self.optimizer_D.step() self.optimizer_D.step()
if _profile: if _profile:
print("Disc step %f" % (time() - _t,)) print("Disc step %f" % (time() - _t,))
_t = time() _t = time()
# D_grad.
if self.spsr_enabled and self.cri_grad_gan and step >= self.G_warmup:
for p in self.netD_grad.parameters():
p.requires_grad = True
self.optimizer_D_grad.zero_grad()
for var_ref, fake_H in zip(self.var_ref, self.fake_H):
fake_H_grad = self.get_grad(fake_H)
var_ref_grad = self.get_grad(var_ref)
pred_d_real_grad = self.netD_grad(var_ref_grad)
pred_d_fake_grad = self.netD_grad(fake_H_grad.detach()) # detach to avoid BP to G
if self.opt['train']['gan_type'] == 'gan':
l_d_real_grad = self.cri_grad_gan(pred_d_real_grad - torch.mean(pred_d_fake_grad), True)
l_d_fake_grad = self.cri_grad_gan(pred_d_fake_grad - torch.mean(pred_d_real_grad), False)
elif self.opt['train']['gan_type'] == 'pixgan':
real = torch.ones_like(pred_d_real_grad)
fake = torch.zeros_like(pred_d_fake_grad)
l_d_real_grad = self.cri_grad_gan(pred_d_real_grad - torch.mean(pred_d_fake_grad), real)
l_d_fake_grad = self.cri_grad_gan(pred_d_fake_grad - torch.mean(pred_d_real_grad), fake)
l_d_total_grad = (l_d_real_grad + l_d_fake_grad) / 2
l_d_total_grad /= self.mega_batch_factor
with amp.scale_loss(l_d_total_grad, self.optimizer_D_grad, loss_id=2) as l_d_total_grad_scaled:
l_d_total_grad_scaled.backward()
self.optimizer_D_grad.step()
# Log sample images from first microbatch. # Log sample images from first microbatch.
if step % self.img_debug_steps == 0: if step % self.img_debug_steps == 0:
sample_save_path = os.path.join(self.opt['path']['models'], "..", "temp") sample_save_path = os.path.join(self.opt['path']['models'], "..", "temp")
@ -562,6 +690,8 @@ class SRGANModel(BaseModel):
os.makedirs(os.path.join(sample_save_path, "pix"), exist_ok=True) os.makedirs(os.path.join(sample_save_path, "pix"), exist_ok=True)
os.makedirs(os.path.join(sample_save_path, "disc"), exist_ok=True) os.makedirs(os.path.join(sample_save_path, "disc"), exist_ok=True)
os.makedirs(os.path.join(sample_save_path, "ref"), exist_ok=True) os.makedirs(os.path.join(sample_save_path, "ref"), exist_ok=True)
if self.spsr_enabled:
os.makedirs(os.path.join(sample_save_path, "gen_grad"), exist_ok=True)
# fed_LQ is not chunked. # fed_LQ is not chunked.
for i in range(self.mega_batch_factor): for i in range(self.mega_batch_factor):
@ -570,6 +700,8 @@ class SRGANModel(BaseModel):
utils.save_image(self.pix[i].cpu(), os.path.join(sample_save_path, "pix", "%05i_%02i.png" % (step, i))) utils.save_image(self.pix[i].cpu(), os.path.join(sample_save_path, "pix", "%05i_%02i.png" % (step, i)))
utils.save_image(self.fake_GenOut[i].cpu(), os.path.join(sample_save_path, "gen", "%05i_%02i.png" % (step, i))) utils.save_image(self.fake_GenOut[i].cpu(), os.path.join(sample_save_path, "gen", "%05i_%02i.png" % (step, i)))
utils.save_image(self.fea_GenOut[i].cpu(), os.path.join(sample_save_path, "gen_fea", "%05i_%02i.png" % (step, i))) utils.save_image(self.fea_GenOut[i].cpu(), os.path.join(sample_save_path, "gen_fea", "%05i_%02i.png" % (step, i)))
if self.spsr_enabled:
utils.save_image(self.spsr_grad_GenOut[i].cpu(), os.path.join(sample_save_path, "gen_grad", "%05i_%02i.png" % (step, i)))
if self.l_gan_w > 0 and step >= self.G_warmup and 'pixgan' in self.opt['train']['gan_type']: if self.l_gan_w > 0 and step >= self.G_warmup and 'pixgan' in self.opt['train']['gan_type']:
utils.save_image(var_ref_skips[i].cpu(), os.path.join(sample_save_path, "ref", "%05i_%02i.png" % (step, i))) utils.save_image(var_ref_skips[i].cpu(), os.path.join(sample_save_path, "ref", "%05i_%02i.png" % (step, i)))
utils.save_image(self.fake_H[i], os.path.join(sample_save_path, "disc_fake", "fake%05i_%02i.png" % (step, i))) utils.save_image(self.fake_H[i], os.path.join(sample_save_path, "disc_fake", "fake%05i_%02i.png" % (step, i)))
@ -594,11 +726,19 @@ class SRGANModel(BaseModel):
self.add_log_entry('l_d_fea_real', l_d_fea_real.item() * self.mega_batch_factor) self.add_log_entry('l_d_fea_real', l_d_fea_real.item() * self.mega_batch_factor)
self.add_log_entry('l_d_fake_total', l_d_fake.item() * self.mega_batch_factor) self.add_log_entry('l_d_fake_total', l_d_fake.item() * self.mega_batch_factor)
self.add_log_entry('l_d_real_total', l_d_real.item() * self.mega_batch_factor) self.add_log_entry('l_d_real_total', l_d_real.item() * self.mega_batch_factor)
if self.spsr_enabled:
if self.cri_pix_branch:
self.add_log_entry('l_g_pix_grad_branch', l_g_pix_grad_branch.item())
if self.l_gan_w > 0 and step >= self.G_warmup: if self.l_gan_w > 0 and step >= self.G_warmup:
self.add_log_entry('l_d_real', l_d_real_log.item()) self.add_log_entry('l_d_real', l_d_real_log.item())
self.add_log_entry('l_d_fake', l_d_fake_log.item()) self.add_log_entry('l_d_fake', l_d_fake_log.item())
self.add_log_entry('D_fake', torch.mean(pred_d_fake.detach())) self.add_log_entry('D_fake', torch.mean(pred_d_fake.detach()))
self.add_log_entry('D_diff', torch.mean(pred_d_fake) - torch.mean(pred_d_real)) self.add_log_entry('D_diff', torch.mean(pred_d_fake) - torch.mean(pred_d_real))
if self.spsr_enabled:
self.add_log_entry('l_d_real_grad', l_d_real_grad.item())
self.add_log_entry('l_d_fake_grad', l_d_fake_grad.item())
self.add_log_entry('D_fake', torch.mean(pred_d_fake_grad.detach()))
self.add_log_entry('D_diff', torch.mean(pred_d_fake_grad) - torch.mean(pred_d_real_grad))
# Log learning rates. # Log learning rates.
for i, pg in enumerate(self.optimizer_G.param_groups): for i, pg in enumerate(self.optimizer_G.param_groups):
@ -685,7 +825,16 @@ class SRGANModel(BaseModel):
def test(self): def test(self):
self.netG.eval() self.netG.eval()
with torch.no_grad(): with torch.no_grad():
-    self.fake_GenOut = [self.netG(self.var_L[0])]
+    if self.spsr_enabled:
+        self.fake_H_branch = []
+        self.fake_GenOut = []
+        self.grad_LR = []
+        fake_H_branch, fake_GenOut, grad_LR = self.netG(self.var_L[0])
+        self.fake_H_branch.append(fake_H_branch)
+        self.fake_GenOut.append(fake_GenOut)
+        self.grad_LR.append(grad_LR)
+    else:
+        self.fake_GenOut = [self.netG(self.var_L[0])]
self.netG.train() self.netG.train()
# Fetches a summary of the log. # Fetches a summary of the log.
@ -713,6 +862,9 @@ class SRGANModel(BaseModel):
out_dict['rlt'] = gen_batch.detach().float().cpu() out_dict['rlt'] = gen_batch.detach().float().cpu()
if need_GT: if need_GT:
out_dict['GT'] = self.var_H[0].detach().float().cpu() out_dict['GT'] = self.var_H[0].detach().float().cpu()
if self.spsr_enabled:
out_dict['SR_branch'] = self.fake_H_branch[0].float().cpu()
out_dict['LR_grad'] = self.grad_LR[0].float().cpu()
return out_dict return out_dict
def print_network(self): def print_network(self):
@ -762,6 +914,11 @@ class SRGANModel(BaseModel):
if self.opt['is_train'] and load_path_D is not None: if self.opt['is_train'] and load_path_D is not None:
logger.info('Loading model for D [{:s}] ...'.format(load_path_D)) logger.info('Loading model for D [{:s}] ...'.format(load_path_D))
self.load_network(load_path_D, self.netD, self.opt['path']['strict_load']) self.load_network(load_path_D, self.netD, self.opt['path']['strict_load'])
if self.spsr_enabled:
load_path_D_grad = self.opt['path']['pretrain_model_D_grad']
if self.opt['is_train'] and load_path_D_grad is not None:
logger.info('Loading pretrained model for D_grad [{:s}] ...'.format(load_path_D_grad))
self.load_network(load_path_D_grad, self.netD_grad)
def load_random_corruptor(self): def load_random_corruptor(self):
if self.netC is None: if self.netC is None:
@ -774,3 +931,4 @@ class SRGANModel(BaseModel):
def save(self, iter_step): def save(self, iter_step):
self.save_network(self.netG, 'G', iter_step) self.save_network(self.netG, 'G', iter_step)
self.save_network(self.netD, 'D', iter_step) self.save_network(self.netD, 'D', iter_step)
self.save_network(self.netD_grad, 'D_grad', iter_step)

View File

@ -7,11 +7,11 @@ def create_model(opt):
# image restoration # image restoration
if model == 'sr': # PSNR-oriented super resolution if model == 'sr': # PSNR-oriented super resolution
from .SR_model import SRModel as M from .SR_model import SRModel as M
-    elif model == 'srgan' or model == 'corruptgan':  # GAN-based super resolution(SRGAN / ESRGAN), or corruption use same logic
+    elif model == 'srgan' or model == 'corruptgan' or model == 'spsrgan':
        from .SRGAN_model import SRGANModel as M
    elif model == 'feat':
        from .feature_model import FeatureModel as M
-    if model == 'spsr':
+    elif model == 'spsr':
        from .SPSR_model import SPSRModel as M
    else:
        raise NotImplementedError('Model [{:s}] not recognized.'.format(model))

View File

@ -0,0 +1,226 @@
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from models.archs import SPSR_util as B
from .RRDBNet_arch import RRDB
class ImageGradient(nn.Module):
def __init__(self):
super(ImageGradient, self).__init__()
kernel_v = [[0, -1, 0],
[0, 0, 0],
[0, 1, 0]]
kernel_h = [[0, 0, 0],
[-1, 0, 1],
[0, 0, 0]]
kernel_h = torch.FloatTensor(kernel_h).unsqueeze(0).unsqueeze(0)
kernel_v = torch.FloatTensor(kernel_v).unsqueeze(0).unsqueeze(0)
self.weight_h = nn.Parameter(data = kernel_h, requires_grad = False).cuda()
self.weight_v = nn.Parameter(data = kernel_v, requires_grad = False).cuda()
def forward(self, x):
x0 = x[:, 0]
x1 = x[:, 1]
x2 = x[:, 2]
x0_v = F.conv2d(x0.unsqueeze(1), self.weight_v, padding=2)
x0_h = F.conv2d(x0.unsqueeze(1), self.weight_h, padding=2)
x1_v = F.conv2d(x1.unsqueeze(1), self.weight_v, padding=2)
x1_h = F.conv2d(x1.unsqueeze(1), self.weight_h, padding=2)
x2_v = F.conv2d(x2.unsqueeze(1), self.weight_v, padding=2)
x2_h = F.conv2d(x2.unsqueeze(1), self.weight_h, padding=2)
x0 = torch.sqrt(torch.pow(x0_v, 2) + torch.pow(x0_h, 2) + 1e-6)
x1 = torch.sqrt(torch.pow(x1_v, 2) + torch.pow(x1_h, 2) + 1e-6)
x2 = torch.sqrt(torch.pow(x2_v, 2) + torch.pow(x2_h, 2) + 1e-6)
x = torch.cat([x0, x1, x2], dim=1)
return x
class ImageGradientNoPadding(nn.Module):
def __init__(self):
super(ImageGradientNoPadding, self).__init__()
kernel_v = [[0, -1, 0],
[0, 0, 0],
[0, 1, 0]]
kernel_h = [[0, 0, 0],
[-1, 0, 1],
[0, 0, 0]]
kernel_h = torch.FloatTensor(kernel_h).unsqueeze(0).unsqueeze(0)
kernel_v = torch.FloatTensor(kernel_v).unsqueeze(0).unsqueeze(0)
self.weight_h = nn.Parameter(data = kernel_h, requires_grad = False)
self.weight_v = nn.Parameter(data = kernel_v, requires_grad = False)
def forward(self, x):
x_list = []
for i in range(x.shape[1]):
x_i = x[:, i]
x_i_v = F.conv2d(x_i.unsqueeze(1), self.weight_v, padding=1)
x_i_h = F.conv2d(x_i.unsqueeze(1), self.weight_h, padding=1)
x_i = torch.sqrt(torch.pow(x_i_v, 2) + torch.pow(x_i_h, 2) + 1e-6)
x_list.append(x_i)
x = torch.cat(x_list, dim = 1)
return x
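What the two modules above compute, per channel, is a gradient-magnitude map from fixed (non-trainable) vertical and horizontal difference kernels; ImageGradientNoPadding uses padding=1 so the output keeps the input's spatial size. A standalone sketch of the same computation:

import torch
import torch.nn.functional as F

kernel_v = torch.tensor([[0., -1., 0.], [0., 0., 0.], [0., 1., 0.]]).view(1, 1, 3, 3)
kernel_h = torch.tensor([[0., 0., 0.], [-1., 0., 1.], [0., 0., 0.]]).view(1, 1, 3, 3)

img = torch.rand(1, 3, 8, 8)                # N x C x H x W
grads = []
for c in range(img.shape[1]):               # one conv per channel, as in the modules above
    ch = img[:, c:c + 1]
    gv = F.conv2d(ch, kernel_v, padding=1)
    gh = F.conv2d(ch, kernel_h, padding=1)
    grads.append(torch.sqrt(gv ** 2 + gh ** 2 + 1e-6))
grad_map = torch.cat(grads, dim=1)
print(grad_map.shape)                       # torch.Size([1, 3, 8, 8])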
####################
# Generator
####################
class SPSRNet(nn.Module):
def __init__(self, in_nc, out_nc, nf, nb, gc=32, upscale=4, norm_type=None, \
act_type='leakyrelu', mode='CNA', upsample_mode='upconv'):
super(SPSRNet, self).__init__()
n_upscale = int(math.log(upscale, 2))
if upscale == 3:
n_upscale = 1
fea_conv = B.conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None)
rb_blocks = [RRDB(nf, gc=32) for _ in range(nb)]
LR_conv = B.conv_block(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None, mode=mode)
if upsample_mode == 'upconv':
upsample_block = B.upconv_blcok
elif upsample_mode == 'pixelshuffle':
upsample_block = B.pixelshuffle_block
else:
raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode))
if upscale == 3:
upsampler = upsample_block(nf, nf, 3, act_type=act_type)
else:
upsampler = [upsample_block(nf, nf, act_type=act_type) for _ in range(n_upscale)]
self.HR_conv0_new = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type)
self.HR_conv1_new = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=None)
self.model = B.sequential(fea_conv, B.ShortcutBlock(B.sequential(*rb_blocks, LR_conv)),\
*upsampler, self.HR_conv0_new)
self.get_g_nopadding = ImageGradientNoPadding()
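# Gradient (structure) branch: extracts features from the image-gradient map and merges in
# trunk features at four fusion points.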
self.b_fea_conv = B.conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None)
self.b_concat_1 = B.conv_block(2*nf, nf, kernel_size=3, norm_type=None, act_type = None)
self.b_block_1 = RRDB(nf*2, gc=32)
self.b_concat_2 = B.conv_block(2*nf, nf, kernel_size=3, norm_type=None, act_type = None)
self.b_block_2 = RRDB(nf*2, gc=32)
self.b_concat_3 = B.conv_block(2*nf, nf, kernel_size=3, norm_type=None, act_type = None)
self.b_block_3 = RRDB(nf*2, gc=32)
self.b_concat_4 = B.conv_block(2*nf, nf, kernel_size=3, norm_type=None, act_type = None)
self.b_block_4 = RRDB(nf*2, gc=32)
self.b_LR_conv = B.conv_block(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None, mode=mode)
if upsample_mode == 'upconv':
upsample_block = B.upconv_blcok
elif upsample_mode == 'pixelshuffle':
upsample_block = B.pixelshuffle_block
else:
raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode))
if upscale == 3:
b_upsampler = upsample_block(nf, nf, 3, act_type=act_type)
else:
b_upsampler = [upsample_block(nf, nf, act_type=act_type) for _ in range(n_upscale)]
b_HR_conv0 = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type)
b_HR_conv1 = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=None)
self.b_module = B.sequential(*b_upsampler, b_HR_conv0, b_HR_conv1)
self.conv_w = B.conv_block(nf, out_nc, kernel_size=1, norm_type=None, act_type=None)
# Note: "_branch_pretrain" is a special tag used to denote parameters that get pretrained before the rest.
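# A minimal sketch (not part of this commit; netG is an assumed generator instance) of how such a
# tag could be consumed to build a pretraining-only optimizer:
#   pretrain_params = [p for n, p in netG.named_parameters() if '_branch_pretrain' in n]
#   optimizer = torch.optim.Adam(pretrain_params, lr=1e-4)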
self._branch_pretrain_concat = B.conv_block(nf*2, nf, kernel_size=3, norm_type=None, act_type=None)
self._branch_pretrain_block = RRDB(nf*2, gc=32)
self._branch_pretrain_HR_conv0 = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type)
self._branch_pretrain_HR_conv1 = B.conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None)
def forward(self, x):
x_grad = self.get_g_nopadding(x)
x = self.model[0](x)
x, block_list = self.model[1](x)
x_ori = x
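# Run the trunk RRDBs in groups of five, tapping intermediate features (x_fea1..x_fea4) for the
# gradient branch below (assumes the trunk contains at least 20 blocks).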
for i in range(5):
x = block_list[i](x)
x_fea1 = x
for i in range(5):
x = block_list[i+5](x)
x_fea2 = x
for i in range(5):
x = block_list[i+10](x)
x_fea3 = x
for i in range(5):
x = block_list[i+15](x)
x_fea4 = x
x = block_list[20:](x)
# shortcut
x = x_ori + x
x = self.model[2:](x)
x = self.HR_conv1_new(x)
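# Gradient branch: fuse gradient-map features with the trunk features tapped above at four depths.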
x_b_fea = self.b_fea_conv(x_grad)
x_cat_1 = torch.cat([x_b_fea, x_fea1], dim=1)
x_cat_1 = self.b_block_1(x_cat_1)
x_cat_1 = self.b_concat_1(x_cat_1)
x_cat_2 = torch.cat([x_cat_1, x_fea2], dim=1)
x_cat_2 = self.b_block_2(x_cat_2)
x_cat_2 = self.b_concat_2(x_cat_2)
x_cat_3 = torch.cat([x_cat_2, x_fea3], dim=1)
x_cat_3 = self.b_block_3(x_cat_3)
x_cat_3 = self.b_concat_3(x_cat_3)
x_cat_4 = torch.cat([x_cat_3, x_fea4], dim=1)
x_cat_4 = self.b_block_4(x_cat_4)
x_cat_4 = self.b_concat_4(x_cat_4)
x_cat_4 = self.b_LR_conv(x_cat_4)
# shortcut
x_cat_4 = x_cat_4 + x_b_fea
x_branch = self.b_module(x_cat_4)
x_out_branch = self.conv_w(x_branch)
########
x_branch_d = x_branch
x__branch_pretrain_cat = torch.cat([x_branch_d, x], dim=1)
x__branch_pretrain_cat = self._branch_pretrain_block(x__branch_pretrain_cat)
x_out = self._branch_pretrain_concat(x__branch_pretrain_cat)
x_out = self._branch_pretrain_HR_conv0(x_out)
x_out = self._branch_pretrain_HR_conv1(x_out)
#########
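# Returns the gradient-branch output, the fused SR image, and the gradient map of the input.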
return x_out_branch, x_out, x_grad

View File

@ -5,8 +5,6 @@ import torch.nn as nn
####################
# Basic blocks
####################
def act(act_type, inplace=True, neg_slope=0.2, n_prelu=1):
# helper selecting activation
# neg_slope: for leakyrelu and init of prelu
@ -134,101 +132,6 @@ def conv_block(in_nc, out_nc, kernel_size, stride=1, dilation=1, groups=1, bias=
return sequential(n, a, p, c)
####################
# Useful blocks
####################
class ResNetBlock(nn.Module):
'''
ResNet Block, 3-3 style
with extra residual scaling used in EDSR
(Enhanced Deep Residual Networks for Single Image Super-Resolution, CVPRW 17)
'''
def __init__(self, in_nc, mid_nc, out_nc, kernel_size=3, stride=1, dilation=1, groups=1, \
bias=True, pad_type='zero', norm_type=None, act_type='relu', mode='CNA', res_scale=1):
super(ResNetBlock, self).__init__()
conv0 = conv_block(in_nc, mid_nc, kernel_size, stride, dilation, groups, bias, pad_type, \
norm_type, act_type, mode)
if mode == 'CNA':
act_type = None
if mode == 'CNAC': # Residual path: |-CNAC-|
act_type = None
norm_type = None
conv1 = conv_block(mid_nc, out_nc, kernel_size, stride, dilation, groups, bias, pad_type, \
norm_type, act_type, mode)
# if in_nc != out_nc:
# self.project = conv_block(in_nc, out_nc, 1, stride, dilation, 1, bias, pad_type, \
# None, None)
# print('Need a projecter in ResNetBlock.')
# else:
# self.project = lambda x:x
self.res = sequential(conv0, conv1)
self.res_scale = res_scale
def forward(self, x):
res = self.res(x).mul(self.res_scale)
return x + res
class ResidualDenseBlock_5C(nn.Module):
'''
Residual Dense Block
style: 5 convs
The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)
'''
def __init__(self, nc, kernel_size=3, gc=32, stride=1, bias=True, pad_type='zero', \
norm_type=None, act_type='leakyrelu', mode='CNA'):
super(ResidualDenseBlock_5C, self).__init__()
# gc: growth channel, i.e. intermediate channels
self.conv1 = conv_block(nc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, \
norm_type=norm_type, act_type=act_type, mode=mode)
self.conv2 = conv_block(nc+gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, \
norm_type=norm_type, act_type=act_type, mode=mode)
self.conv3 = conv_block(nc+2*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, \
norm_type=norm_type, act_type=act_type, mode=mode)
self.conv4 = conv_block(nc+3*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, \
norm_type=norm_type, act_type=act_type, mode=mode)
if mode == 'CNA':
last_act = None
else:
last_act = act_type
self.conv5 = conv_block(nc+4*gc, nc, 3, stride, bias=bias, pad_type=pad_type, \
norm_type=norm_type, act_type=last_act, mode=mode)
def forward(self, x):
x1 = self.conv1(x)
x2 = self.conv2(torch.cat((x, x1), 1))
x3 = self.conv3(torch.cat((x, x1, x2), 1))
x4 = self.conv4(torch.cat((x, x1, x2, x3), 1))
x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
return x5.mul(0.2) + x
class RRDB(nn.Module):
'''
Residual in Residual Dense Block
(ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks)
'''
def __init__(self, nc, kernel_size=3, gc=32, stride=1, bias=True, pad_type='zero', \
norm_type=None, act_type='leakyrelu', mode='CNA'):
super(RRDB, self).__init__()
self.RDB1 = ResidualDenseBlock_5C(nc, kernel_size, gc, stride, bias, pad_type, \
norm_type, act_type, mode)
self.RDB2 = ResidualDenseBlock_5C(nc, kernel_size, gc, stride, bias, pad_type, \
norm_type, act_type, mode)
self.RDB3 = ResidualDenseBlock_5C(nc, kernel_size, gc, stride, bias, pad_type, \
norm_type, act_type, mode)
def forward(self, x):
out = self.RDB1(x)
out = self.RDB2(out)
out = self.RDB3(out)
return out.mul(0.2) + x
####################
# Upsampler
####################

View File

@ -78,6 +78,65 @@ class Discriminator_VGG_128(nn.Module):
return out
class Discriminator_VGG_128_GN(nn.Module):
# input_img_factor = multiplier to support images over 128x128. Only certain factors are supported.
def __init__(self, in_nc, nf, input_img_factor=1):
super(Discriminator_VGG_128_GN, self).__init__()
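# GroupNorm variant of the 128x128 VGG-style discriminator; the shape comments below are
# [channels, height, width] assuming nf=64 and a 128x128 input.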
# [64, 128, 128]
self.conv0_0 = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
self.conv0_1 = nn.Conv2d(nf, nf, 4, 2, 1, bias=False)
self.bn0_1 = nn.GroupNorm(8, nf, affine=True)
# [64, 64, 64]
self.conv1_0 = nn.Conv2d(nf, nf * 2, 3, 1, 1, bias=False)
self.bn1_0 = nn.GroupNorm(8, nf * 2, affine=True)
self.conv1_1 = nn.Conv2d(nf * 2, nf * 2, 4, 2, 1, bias=False)
self.bn1_1 = nn.GroupNorm(8, nf * 2, affine=True)
# [128, 32, 32]
self.conv2_0 = nn.Conv2d(nf * 2, nf * 4, 3, 1, 1, bias=False)
self.bn2_0 = nn.GroupNorm(8, nf * 4, affine=True)
self.conv2_1 = nn.Conv2d(nf * 4, nf * 4, 4, 2, 1, bias=False)
self.bn2_1 = nn.GroupNorm(8, nf * 4, affine=True)
# [256, 16, 16]
self.conv3_0 = nn.Conv2d(nf * 4, nf * 8, 3, 1, 1, bias=False)
self.bn3_0 = nn.GroupNorm(8, nf * 8, affine=True)
self.conv3_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False)
self.bn3_1 = nn.GroupNorm(8, nf * 8, affine=True)
# [512, 8, 8]
self.conv4_0 = nn.Conv2d(nf * 8, nf * 8, 3, 1, 1, bias=False)
self.bn4_0 = nn.GroupNorm(8, nf * 8, affine=True)
self.conv4_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False)
self.bn4_1 = nn.GroupNorm(8, nf * 8, affine=True)
final_nf = nf * 8
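# After five stride-2 downsamplings, a (128*input_img_factor) input is reduced to a
# (4*input_img_factor) x (4*input_img_factor) spatial grid, which sets the linear layer's input size.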
self.linear1 = nn.Linear(final_nf * 4 * input_img_factor * 4 * input_img_factor, 100)
self.linear2 = nn.Linear(100, 1)
# activation function
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
def forward(self, x):
fea = self.lrelu(self.conv0_0(x))
fea = self.lrelu(self.bn0_1(self.conv0_1(fea)))
#fea = torch.cat([fea, skip_med], dim=1)
fea = self.lrelu(self.bn1_0(self.conv1_0(fea)))
fea = self.lrelu(self.bn1_1(self.conv1_1(fea)))
#fea = torch.cat([fea, skip_lo], dim=1)
fea = self.lrelu(self.bn2_0(self.conv2_0(fea)))
fea = self.lrelu(self.bn2_1(self.conv2_1(fea)))
fea = self.lrelu(self.bn3_0(self.conv3_0(fea)))
fea = self.lrelu(self.bn3_1(self.conv3_1(fea)))
fea = self.lrelu(self.bn4_0(self.conv4_0(fea)))
fea = self.lrelu(self.bn4_1(self.conv4_1(fea)))
fea = fea.contiguous().view(fea.size(0), -1)
fea = self.lrelu(self.linear1(fea))
out = self.linear2(fea)
return out
class Discriminator_VGG_PixLoss(nn.Module):
def __init__(self, in_nc, nf):
super(Discriminator_VGG_PixLoss, self).__init__()

View File

@ -11,6 +11,8 @@ import models.archs.feature_arch as feature_arch
import models.archs.SwitchedResidualGenerator_arch as SwitchedGen_arch
import models.archs.SRG1_arch as srg1
import models.archs.ProgressiveSrg_arch as psrg
import models.archs.SPSR_arch as spsr
import models.archs.arch_util as arch_util
import functools
from collections import OrderedDict
@ -97,6 +99,12 @@ def define_G(opt, net_key='network_G'):
initial_temp=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step'],
upsample_factor=scale, add_scalable_noise_to_transforms=opt_net['add_noise'],
start_step=opt_net['start_step'])
elif which_model == 'spsr_net':
netG = spsr.SPSRNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'],
nb=opt_net['nb'], gc=opt_net['gc'], upscale=opt_net['scale'], norm_type=opt_net['norm_type'],
act_type='leakyrelu', mode=opt_net['mode'], upsample_mode='upconv')
if opt['is_train']:
arch_util.initialize_weights(netG, scale=.1)
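# Scaling the initial weights down (0.1) is a common trick for stabilizing very deep residual generators early in training.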
# image corruption
elif which_model == 'HighToLowResNet':
@ -119,6 +127,8 @@ def define_D_net(opt_net, img_sz=None):
if which_model == 'discriminator_vgg_128':
netD = SRGAN_arch.Discriminator_VGG_128(in_nc=opt_net['in_nc'], nf=opt_net['nf'], input_img_factor=img_sz // 128, extra_conv=opt_net['extra_conv'])
elif which_model == 'discriminator_vgg_128_gn':
netD = SRGAN_arch.Discriminator_VGG_128_GN(in_nc=opt_net['in_nc'], nf=opt_net['nf'], input_img_factor=img_sz // 128)
elif which_model == 'discriminator_resnet':
netD = DiscriminatorResnet_arch.fixup_resnet34(num_filters=opt_net['nf'], num_classes=1, input_img_size=img_sz)
elif which_model == 'discriminator_resnet_passthrough':

View File

@ -115,7 +115,11 @@ def check_resume(opt, resume_iter):
opt['path']['pretrain_model_G'] = osp.join(opt['path']['models'],
'{}_G.pth'.format(resume_iter))
logger.info('Set [pretrain_model_G] to ' + opt['path']['pretrain_model_G'])
-if 'gan' in opt['model']:
+if 'gan' in opt['model'] or 'spsr' in opt['model']:
opt['path']['pretrain_model_D'] = osp.join(opt['path']['models'],
'{}_D.pth'.format(resume_iter))
logger.info('Set [pretrain_model_D] to ' + opt['path']['pretrain_model_D'])
if 'spsr' in opt['model']:
opt['path']['pretrain_model_D_grad'] = osp.join(opt['path']['models'],
'{}_D_grad.pth'.format(resume_iter))
logger.info('Set [pretrain_model_D_grad] to ' + opt['path']['pretrain_model_D_grad'])

View File

@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs):
def main():
#### options
parser = argparse.ArgumentParser()
-parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_spsr.yml')
+parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_spsr_rrdb.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)
@ -215,7 +215,7 @@ def main():
logger.info(message)
#### validation
if opt['datasets'].get('val', None) and current_step % opt['train']['val_freq'] == 0:
-if opt['model'] in ['sr', 'srgan', 'corruptgan', 'spsr'] and rank <= 0:  # image restoration validation
+if opt['model'] in ['sr', 'srgan', 'corruptgan', 'spsrgan'] and rank <= 0:  # image restoration validation
model.force_restore_swapout()
val_batch_sz = 1 if 'batch_size' not in opt['datasets']['val'].keys() else opt['datasets']['val']['batch_size']
# does not support multi-GPU validation