forked from mrq/DL-Art-School
Integrate SPSR into SRGAN_model
SPSR_model really isn't that different from SRGAN_model. Rather than continuing to re-implement everything I've done in SRGAN_model, port the new stuff from SPSR over. This really demonstrates the need to refactor SRGAN_model a bit to make it cleaner. It is quite the beast these days.
This commit is contained in:
parent
c8da78966b
commit
328afde9c0
|
@ -7,84 +7,14 @@ import torch.nn as nn
|
|||
from torch.optim import lr_scheduler
|
||||
from apex import amp
|
||||
|
||||
import models.SPSR_networks as networks
|
||||
import models.networks as networks
|
||||
from .base_model import BaseModel
|
||||
from models.SPSR_modules.loss import GANLoss
|
||||
from models.loss import GANLoss
|
||||
import torchvision.utils as utils
|
||||
from .archs.SPSR_arch import ImageGradient, ImageGradientNoPadding
|
||||
|
||||
logger = logging.getLogger('base')
|
||||
|
||||
import torch.nn.functional as F
|
||||
|
||||
class Get_gradient(nn.Module):
    """Per-channel gradient magnitude of the first three channels of an image.

    Each channel is filtered with a vertical and a horizontal central-difference
    kernel and the two responses are combined as ``sqrt(v^2 + h^2 + 1e-6)``.
    The convolutions use ``padding=2`` with a 3x3 kernel, so the output is two
    pixels larger than the input in each spatial dimension.

    The kernels are kept as plain (unregistered) tensors and are moved to the
    device/dtype of the input on every call, so the module works on CPU as well
    as on any CUDA device without having to be moved explicitly.
    """

    def __init__(self):
        super(Get_gradient, self).__init__()
        kernel_v = [[0, -1, 0],
                    [0, 0, 0],
                    [0, 1, 0]]
        kernel_h = [[0, 0, 0],
                    [-1, 0, 1],
                    [0, 0, 0]]
        # Shape (1, 1, 3, 3): one output channel, one input channel.
        # Deliberately not nn.Parameter / buffers: the kernels are constants and
        # must not show up in parameters() or state_dict().
        self.weight_h = torch.FloatTensor(kernel_h).unsqueeze(0).unsqueeze(0)
        self.weight_v = torch.FloatTensor(kernel_v).unsqueeze(0).unsqueeze(0)

    def forward(self, x):
        """Return the gradient-magnitude map for channels 0, 1 and 2 of ``x``.

        ``x`` is indexed as ``x[:, c]`` and each slice is fed to ``conv2d`` as a
        single-channel image; the three results are concatenated on dim 1.
        """
        # Follow the input instead of assuming a CUDA device is available.
        weight_v = self.weight_v.to(device=x.device, dtype=x.dtype)
        weight_h = self.weight_h.to(device=x.device, dtype=x.dtype)

        magnitudes = []
        for channel in range(3):
            plane = x[:, channel].unsqueeze(1)
            grad_v = F.conv2d(plane, weight_v, padding=2)
            grad_h = F.conv2d(plane, weight_h, padding=2)
            # The epsilon keeps sqrt differentiable where the gradient is zero.
            magnitudes.append(torch.sqrt(torch.pow(grad_v, 2) + torch.pow(grad_h, 2) + 1e-6))

        return torch.cat(magnitudes, dim=1)
|
||||
|
||||
class Get_gradient_nopadding(nn.Module):
    """Per-channel gradient magnitude that keeps the input's spatial size.

    Same computation as a vertical/horizontal central-difference filter pair
    combined as ``sqrt(v^2 + h^2 + 1e-6)``, applied to channels 0, 1 and 2.
    The 3x3 convolutions use ``padding=1``, so the output has the same height
    and width as the input.

    The kernels are plain tensors that are moved to the device/dtype of the
    input on every call, so no CUDA device is required at construction time.
    """

    def __init__(self):
        super(Get_gradient_nopadding, self).__init__()
        kernel_v = [[0, -1, 0],
                    [0, 0, 0],
                    [0, 1, 0]]
        kernel_h = [[0, 0, 0],
                    [-1, 0, 1],
                    [0, 0, 0]]
        # Shape (1, 1, 3, 3). Kept out of parameters()/state_dict() on purpose:
        # these are fixed constants, not learnable weights.
        self.weight_h = torch.FloatTensor(kernel_h).unsqueeze(0).unsqueeze(0)
        self.weight_v = torch.FloatTensor(kernel_v).unsqueeze(0).unsqueeze(0)

    def forward(self, x):
        """Return the gradient-magnitude map for channels 0, 1 and 2 of ``x``."""
        # Follow the input instead of assuming a CUDA device is available.
        weight_v = self.weight_v.to(device=x.device, dtype=x.dtype)
        weight_h = self.weight_h.to(device=x.device, dtype=x.dtype)

        magnitudes = []
        for channel in range(3):
            plane = x[:, channel].unsqueeze(1)
            grad_v = F.conv2d(plane, weight_v, padding=1)
            grad_h = F.conv2d(plane, weight_h, padding=1)
            # The epsilon keeps sqrt differentiable where the gradient is zero.
            magnitudes.append(torch.sqrt(torch.pow(grad_v, 2) + torch.pow(grad_h, 2) + 1e-6))

        return torch.cat(magnitudes, dim=1)
|
||||
|
||||
|
||||
class SPSRModel(BaseModel):
|
||||
def __init__(self, opt):
|
||||
super(SPSRModel, self).__init__(opt)
|
||||
|
@ -93,8 +23,8 @@ class SPSRModel(BaseModel):
|
|||
# define networks and load pretrained models
|
||||
self.netG = networks.define_G(opt).to(self.device) # G
|
||||
if self.is_train:
|
||||
self.netD = networks.define_D(opt).to(self.device) # D
|
||||
self.netD_grad = networks.define_D_grad(opt).to(self.device) # D_grad
|
||||
self.netD = networks.define_D(opt).to(self.device) # D
|
||||
self.netD_grad = networks.define_D(opt).to(self.device) # D_grad
|
||||
self.netG.train()
|
||||
self.netD.train()
|
||||
self.netD_grad.train()
|
||||
|
@ -142,8 +72,8 @@ class SPSRModel(BaseModel):
|
|||
self.D_update_ratio = train_opt['D_update_ratio'] if train_opt['D_update_ratio'] else 1
|
||||
self.D_init_iters = train_opt['D_init_iters'] if train_opt['D_init_iters'] else 0
|
||||
# Branch_init_iters
|
||||
self.Branch_pretrain = train_opt['Branch_pretrain'] if train_opt['Branch_pretrain'] else 0
|
||||
self.Branch_init_iters = train_opt['Branch_init_iters'] if train_opt['Branch_init_iters'] else 1
|
||||
self.branch_pretrain = train_opt['branch_pretrain'] if train_opt['branch_pretrain'] else 0
|
||||
self.branch_init_iters = train_opt['branch_init_iters'] if train_opt['branch_init_iters'] else 1
|
||||
|
||||
# gradient_pixel_loss
|
||||
if train_opt['gradient_pixel_weight'] > 0:
|
||||
|
@ -217,8 +147,8 @@ class SPSRModel(BaseModel):
|
|||
raise NotImplementedError('MultiStepLR learning rate scheme is enough.')
|
||||
|
||||
self.log_dict = OrderedDict()
|
||||
self.get_grad = Get_gradient()
|
||||
self.get_grad_nopadding = Get_gradient_nopadding()
|
||||
self.get_grad = ImageGradient()
|
||||
self.get_grad_nopadding = ImageGradientNoPadding()
|
||||
|
||||
def feed_data(self, data, need_HR=True):
|
||||
# LR
|
||||
|
@ -232,6 +162,12 @@ class SPSRModel(BaseModel):
|
|||
|
||||
|
||||
def optimize_parameters(self, step):
|
||||
# Some generators have variants depending on the current step.
|
||||
if hasattr(self.netG.module, "update_for_step"):
|
||||
self.netG.module.update_for_step(step, os.path.join(self.opt['path']['models'], ".."))
|
||||
if hasattr(self.netD.module, "update_for_step"):
|
||||
self.netD.module.update_for_step(step, os.path.join(self.opt['path']['models'], ".."))
|
||||
|
||||
# G
|
||||
for p in self.netD.parameters():
|
||||
p.requires_grad = False
|
||||
|
@ -239,9 +175,8 @@ class SPSRModel(BaseModel):
|
|||
for p in self.netD_grad.parameters():
|
||||
p.requires_grad = False
|
||||
|
||||
|
||||
if(self.Branch_pretrain):
|
||||
if(step < self.Branch_init_iters):
|
||||
if(self.branch_pretrain):
|
||||
if(step < self.branch_init_iters):
|
||||
for k,v in self.netG.named_parameters():
|
||||
if 'f_' not in k :
|
||||
v.requires_grad=False
|
||||
|
@ -250,7 +185,6 @@ class SPSRModel(BaseModel):
|
|||
if 'f_' not in k :
|
||||
v.requires_grad=True
|
||||
|
||||
|
||||
self.optimizer_G.zero_grad()
|
||||
|
||||
self.fake_H_branch = []
|
||||
|
@ -361,43 +295,49 @@ class SPSRModel(BaseModel):
|
|||
os.makedirs(os.path.join(sample_save_path, "hr"), exist_ok=True)
|
||||
os.makedirs(os.path.join(sample_save_path, "lr"), exist_ok=True)
|
||||
os.makedirs(os.path.join(sample_save_path, "gen"), exist_ok=True)
|
||||
os.makedirs(os.path.join(sample_save_path, "gen_grad"), exist_ok=True)
|
||||
# fed_LQ is not chunked.
|
||||
utils.save_image(self.var_H[0].cpu(), os.path.join(sample_save_path, "hr", "%05i.png" % (step,)))
|
||||
utils.save_image(self.var_L[0].cpu(), os.path.join(sample_save_path, "lr", "%05i.png" % (step,)))
|
||||
utils.save_image(self.fake_H[0].cpu(), os.path.join(sample_save_path, "gen", "%05i.png" % (step,)))
|
||||
utils.save_image(self.grad_LR[0].cpu(), os.path.join(sample_save_path, "gen_grad", "%05i.png" % (step,)))
|
||||
|
||||
|
||||
# set log
|
||||
if step % self.D_update_ratio == 0 and step > self.D_init_iters:
|
||||
# G
|
||||
if self.cri_pix:
|
||||
self.log_dict['l_g_pix'] = l_g_pix.item()
|
||||
self.add_log_entry('l_g_pix', l_g_pix.item())
|
||||
if self.cri_fea:
|
||||
self.log_dict['l_g_fea'] = l_g_fea.item()
|
||||
self.add_log_entry('l_g_fea', l_g_fea.item())
|
||||
if self.l_gan_w > 0:
|
||||
self.log_dict['l_g_gan'] = l_g_gan.item()
|
||||
self.add_log_entry('l_g_gan', l_g_gan.item())
|
||||
|
||||
if self.cri_pix_branch: #branch pixel loss
|
||||
self.log_dict['l_g_pix_grad_branch'] = l_g_pix_grad_branch.item()
|
||||
self.add_log_entry('l_g_pix_grad_branch', l_g_pix_grad_branch.item())
|
||||
|
||||
if self.l_gan_w > 0:
|
||||
# D
|
||||
self.log_dict['l_d_real'] = l_d_real.item()
|
||||
self.log_dict['l_d_fake'] = l_d_fake.item()
|
||||
self.add_log_entry('l_d_real', l_d_real.item())
|
||||
self.add_log_entry('l_d_fake', l_d_fake.item())
|
||||
self.add_log_entry('l_d_real_grad', l_d_real_grad.item())
|
||||
self.add_log_entry('l_d_fake_grad', l_d_fake_grad.item())
|
||||
self.add_log_entry('D_real', torch.mean(pred_d_real.detach()))
|
||||
self.add_log_entry('D_fake', torch.mean(pred_d_fake.detach()))
|
||||
self.add_log_entry('D_real_grad', torch.mean(pred_d_real_grad.detach()))
|
||||
self.add_log_entry('D_fake_grad', torch.mean(pred_d_fake_grad.detach()))
|
||||
|
||||
# D_grad
|
||||
self.log_dict['l_d_real_grad'] = l_d_real_grad.item()
|
||||
self.log_dict['l_d_fake_grad'] = l_d_fake_grad.item()
|
||||
|
||||
if self.opt['train']['gan_type'] == 'wgan-gp':
|
||||
self.log_dict['l_d_gp'] = l_d_gp.item()
|
||||
# D outputs
|
||||
self.log_dict['D_real'] = torch.mean(pred_d_real.detach())
|
||||
self.log_dict['D_fake'] = torch.mean(pred_d_fake.detach())
|
||||
|
||||
# D_grad outputs
|
||||
self.log_dict['D_real_grad'] = torch.mean(pred_d_real_grad.detach())
|
||||
self.log_dict['D_fake_grad'] = torch.mean(pred_d_fake_grad.detach())
|
||||
# Allows the log to serve as an easy-to-use rotating buffer.
|
||||
def add_log_entry(self, key, value):
|
||||
key_it = "%s_it" % (key,)
|
||||
log_rotating_buffer_size = 50
|
||||
if key not in self.log_dict.keys():
|
||||
self.log_dict[key] = []
|
||||
self.log_dict[key_it] = 0
|
||||
if len(self.log_dict[key]) < log_rotating_buffer_size:
|
||||
self.log_dict[key].append(value)
|
||||
else:
|
||||
self.log_dict[key][self.log_dict[key_it] % log_rotating_buffer_size] = value
|
||||
self.log_dict[key_it] += 1
|
||||
|
||||
def test(self):
|
||||
self.netG.eval()
|
||||
|
@ -413,8 +353,21 @@ class SPSRModel(BaseModel):
|
|||
|
||||
self.netG.train()
|
||||
|
||||
# Fetches a summary of the log.
|
||||
def get_current_log(self, step):
|
||||
return self.log_dict
|
||||
return_log = {}
|
||||
for k in self.log_dict.keys():
|
||||
if not isinstance(self.log_dict[k], list):
|
||||
continue
|
||||
return_log[k] = sum(self.log_dict[k]) / len(self.log_dict[k])
|
||||
|
||||
# Some generators can do their own metric logging.
|
||||
if hasattr(self.netG.module, "get_debug_values"):
|
||||
return_log.update(self.netG.module.get_debug_values(step))
|
||||
if hasattr(self.netD.module, "get_debug_values"):
|
||||
return_log.update(self.netD.module.get_debug_values(step))
|
||||
|
||||
return return_log
|
||||
|
||||
def get_current_visuals(self, need_HR=True):
|
||||
out_dict = OrderedDict()
|
||||
|
@ -470,6 +423,10 @@ class SPSRModel(BaseModel):
|
|||
if self.opt['is_train'] and load_path_D is not None:
|
||||
logger.info('Loading pretrained model for D [{:s}] ...'.format(load_path_D))
|
||||
self.load_network(load_path_D, self.netD)
|
||||
load_path_D_grad = self.opt['path']['pretrain_model_D_grad']
|
||||
if self.opt['is_train'] and load_path_D_grad is not None:
|
||||
logger.info('Loading pretrained model for D_grad [{:s}] ...'.format(load_path_D_grad))
|
||||
self.load_network(load_path_D_grad, self.netD_grad)
|
||||
|
||||
def compute_fea_loss(self, real, fake):
|
||||
if self.cri_fea is None:
|
||||
|
|
|
@ -1,366 +0,0 @@
|
|||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torchvision
|
||||
from . import block as B
|
||||
|
||||
|
||||
class Get_gradient_nopadding(nn.Module):
    """Size-preserving gradient-magnitude filter applied to every channel.

    Two fixed 3x3 central-difference kernels (vertical and horizontal) are
    stored as frozen parameters ``weight_h`` and ``weight_v``. Each input
    channel is convolved with both (``padding=1``) and the responses are
    merged as ``sqrt(v^2 + h^2 + 1e-6)``.
    """

    def __init__(self):
        super(Get_gradient_nopadding, self).__init__()
        horizontal = torch.FloatTensor([[0, 0, 0],
                                        [-1, 0, 1],
                                        [0, 0, 0]])
        vertical = torch.FloatTensor([[0, -1, 0],
                                      [0, 0, 0],
                                      [0, 1, 0]])
        # Add the (out_channels, in_channels) axes expected by conv2d.
        self.weight_h = nn.Parameter(data=horizontal[None, None], requires_grad=False)
        self.weight_v = nn.Parameter(data=vertical[None, None], requires_grad=False)

    def forward(self, x):
        """Return a tensor with one gradient-magnitude map per channel of ``x``."""
        per_channel = []
        for plane in x.unbind(dim=1):
            # conv2d needs an explicit channel axis of size one.
            plane = plane.unsqueeze(1)
            response_v = F.conv2d(plane, self.weight_v, padding=1)
            response_h = F.conv2d(plane, self.weight_h, padding=1)
            per_channel.append(
                torch.sqrt(torch.pow(response_v, 2) + torch.pow(response_h, 2) + 1e-6))
        return torch.cat(per_channel, dim=1)
|
||||
|
||||
|
||||
####################
|
||||
# Generator
|
||||
####################
|
||||
|
||||
class SPSRNet(nn.Module):
    """Super-resolution generator with an RRDB trunk and a gradient branch.

    The trunk (``self.model``) processes the image; a second branch processes
    the gradient map of the input and is fed with trunk features tapped after
    RRDB blocks 5, 10, 15 and 20. Both are fused at the end.

    ``forward`` returns ``(x_out_branch, x_out, x_grad)``: the branch output,
    the fused output, and the gradient map computed from the input.

    Notes:
        * ``gc`` is accepted but not forwarded; every RRDB is built with
          ``gc=32``.
        * ``forward`` indexes the first 20 trunk blocks individually.
    """

    def __init__(self, in_nc, out_nc, nf, nb, gc=32, upscale=4, norm_type=None,
                 act_type='leakyrelu', mode='CNA', upsample_mode='upconv'):
        super(SPSRNet, self).__init__()

        # x3 is handled by a single upsampling stage; otherwise log2(upscale) stages.
        n_upscale = int(math.log(upscale, 2))
        if upscale == 3:
            n_upscale = 1

        def rrdb(channels):
            return B.RRDB(channels, kernel_size=3, gc=32, stride=1, bias=True, pad_type='zero',
                          norm_type=norm_type, act_type=act_type, mode='CNA')

        def plain_conv(c_in, c_out, kernel_size=3, act=None):
            return B.conv_block(c_in, c_out, kernel_size=kernel_size, norm_type=None, act_type=act)

        def trunk_end_conv():
            return B.conv_block(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None, mode=mode)

        def make_upsampler():
            if upsample_mode == 'upconv':
                upsample_block = B.upconv_blcok
            elif upsample_mode == 'pixelshuffle':
                upsample_block = B.pixelshuffle_block
            else:
                raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode))
            if upscale == 3:
                return upsample_block(nf, nf, 3, act_type=act_type)
            return [upsample_block(nf, nf, act_type=act_type) for _ in range(n_upscale)]

        # ---- image trunk (module creation order is kept as-is) ----
        fea_conv = plain_conv(in_nc, nf)
        rb_blocks = [rrdb(nf) for _ in range(nb)]
        LR_conv = trunk_end_conv()
        upsampler = make_upsampler()

        self.HR_conv0_new = plain_conv(nf, nf, act=act_type)
        self.HR_conv1_new = plain_conv(nf, nf)

        self.model = B.sequential(fea_conv, B.ShortcutBlock(B.sequential(*rb_blocks, LR_conv)),
                                  *upsampler, self.HR_conv0_new)

        # ---- gradient branch ----
        self.get_g_nopadding = Get_gradient_nopadding()
        self.b_fea_conv = plain_conv(in_nc, nf)

        # Four fusion stages: b_concat_k (2nf -> nf conv) and b_block_k (RRDB on 2nf).
        for stage in range(1, 5):
            setattr(self, 'b_concat_%d' % stage, plain_conv(2 * nf, nf))
            setattr(self, 'b_block_%d' % stage, rrdb(nf * 2))

        self.b_LR_conv = trunk_end_conv()

        b_upsampler = make_upsampler()
        b_HR_conv0 = plain_conv(nf, nf, act=act_type)
        b_HR_conv1 = plain_conv(nf, nf)
        self.b_module = B.sequential(*b_upsampler, b_HR_conv0, b_HR_conv1)

        self.conv_w = plain_conv(nf, out_nc, kernel_size=1)

        # ---- fusion head ----
        self.f_concat = plain_conv(nf * 2, nf)
        self.f_block = rrdb(nf * 2)
        self.f_HR_conv0 = plain_conv(nf, nf, act=act_type)
        self.f_HR_conv1 = plain_conv(nf, out_nc)

    def forward(self, x):
        x_grad = self.get_g_nopadding(x)

        shallow = self.model[0](x)
        # The shortcut block hands back its input together with its inner block list.
        shallow, block_list = self.model[1](shallow)

        # Run the first 20 blocks one by one, tapping the features after every fifth.
        feat = shallow
        taps = []
        for idx in range(20):
            feat = block_list[idx](feat)
            if idx % 5 == 4:
                taps.append(feat)

        feat = block_list[20:](feat)
        # short cut
        feat = shallow + feat
        feat = self.model[2:](feat)
        feat = self.HR_conv1_new(feat)

        # Gradient branch: fuse each tapped trunk feature in turn.
        x_b_fea = self.b_fea_conv(x_grad)
        fusion_stages = (
            (self.b_block_1, self.b_concat_1),
            (self.b_block_2, self.b_concat_2),
            (self.b_block_3, self.b_concat_3),
            (self.b_block_4, self.b_concat_4),
        )
        branch = x_b_fea
        for tap, (block, squeeze) in zip(taps, fusion_stages):
            branch = torch.cat([branch, tap], dim=1)
            branch = block(branch)
            branch = squeeze(branch)

        branch = self.b_LR_conv(branch)
        # short cut
        branch = branch + x_b_fea
        x_branch = self.b_module(branch)

        x_out_branch = self.conv_w(x_branch)

        # Merge branch features with the trunk features.
        fused = torch.cat([x_branch, feat], dim=1)
        fused = self.f_block(fused)
        x_out = self.f_concat(fused)
        x_out = self.f_HR_conv0(x_out)
        x_out = self.f_HR_conv1(x_out)

        return x_out_branch, x_out, x_grad
|
||||
|
||||
|
||||
####################
|
||||
# Discriminator
|
||||
####################
|
||||
|
||||
|
||||
# VGG style Discriminator with input size 128*128
|
||||
class Discriminator_VGG_128(nn.Module):
    """VGG-style discriminator; the original comments size it for 128x128 inputs.

    Ten conv blocks alternate a 3x3 stride-1 block with a 4x4 stride-2 block,
    widening from ``base_nf`` to ``base_nf * 8`` channels. The flattened feature
    map is scored by a two-layer classifier producing one value per sample.

    Args:
        in_nc: number of input channels.
        base_nf: width of the first conv block.
        norm_type, act_type, mode: forwarded to ``B.conv_block`` (the first
            block is always built with ``norm_type=None``).
    """

    def __init__(self, in_nc, base_nf, norm_type='batch', act_type='leakyrelu', mode='CNA'):
        super(Discriminator_VGG_128, self).__init__()

        def block(c_in, c_out, kernel_size, stride):
            return B.conv_block(c_in, c_out, kernel_size=kernel_size, stride=stride,
                                norm_type=norm_type, act_type=act_type, mode=mode)

        # features
        # Comments give "spatial size, channels" as noted for a 128x128 input
        # with base_nf=64.
        # 128, 64
        conv0 = B.conv_block(in_nc, base_nf, kernel_size=3, norm_type=None, act_type=act_type,
                             mode=mode)
        conv1 = block(base_nf, base_nf, 4, 2)
        # 64, 64
        conv2 = block(base_nf, base_nf * 2, 3, 1)
        conv3 = block(base_nf * 2, base_nf * 2, 4, 2)
        # 32, 128
        conv4 = block(base_nf * 2, base_nf * 4, 3, 1)
        conv5 = block(base_nf * 4, base_nf * 4, 4, 2)
        # 16, 256
        conv6 = block(base_nf * 4, base_nf * 8, 3, 1)
        conv7 = block(base_nf * 8, base_nf * 8, 4, 2)
        # 8, 512
        conv8 = block(base_nf * 8, base_nf * 8, 3, 1)
        conv9 = block(base_nf * 8, base_nf * 8, 4, 2)
        # 4, 512
        self.features = B.sequential(conv0, conv1, conv2, conv3, conv4, conv5, conv6, conv7, conv8,
                                     conv9)

        # classifier
        # The input width follows the last conv block (base_nf * 8 channels on a
        # 4x4 map) instead of a literal 512, which only matched base_nf == 64.
        self.classifier = nn.Sequential(
            nn.Linear(base_nf * 8 * 4 * 4, 100), nn.LeakyReLU(0.2, True), nn.Linear(100, 1))

    def forward(self, x):
        """Return one score per sample, shape ``(batch, 1)``."""
        x = self.features(x)
        # Flatten everything but the batch dimension for the linear layers.
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x
|
||||
|
||||
|
||||
####################
|
||||
# Perceptual Network
|
||||
####################
|
||||
|
||||
class VGGFeatureExtractor(nn.Module):
    """Frozen, truncated pretrained VGG19 used as a feature extractor.

    Args:
        feature_layer: index of the last entry of ``model.features`` to keep
            (inclusive).
        use_bn: pick ``vgg19_bn`` instead of ``vgg19``.
        use_input_norm: when true, inputs are normalised with the stored
            ``mean``/``std`` buffers before being passed to the network.
        device: device the ``mean``/``std`` tensors are created on.
    """

    def __init__(self,
                 feature_layer=34,
                 use_bn=False,
                 use_input_norm=True,
                 device=torch.device('cpu')):
        super(VGGFeatureExtractor, self).__init__()
        build_backbone = torchvision.models.vgg19_bn if use_bn else torchvision.models.vgg19
        model = build_backbone(pretrained=True)

        self.use_input_norm = use_input_norm
        if self.use_input_norm:
            # Per-channel statistics (presumably the usual ImageNet values), for
            # inputs in [0, 1].
            # For inputs in [-1, 1] use mean [0.485-1, 0.456-1, 0.406-1]
            # and std [0.229*2, 0.224*2, 0.225*2].
            mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device)
            std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device)
            self.register_buffer('mean', mean)
            self.register_buffer('std', std)

        kept_layers = list(model.features.children())[:(feature_layer + 1)]
        self.features = nn.Sequential(*kept_layers)

        # The extractor is never trained: no gradients for its weights.
        for param in self.features.parameters():
            param.requires_grad = False

    def forward(self, x):
        """Return the activations of the truncated network for ``x``."""
        if not self.use_input_norm:
            return self.features(x)
        normalised = (x - self.mean) / self.std
        return self.features(normalised)
|
||||
|
||||
|
||||
class ResNet101FeatureExtractor(nn.Module):
    """Frozen pretrained ResNet-101 truncated to its first eight child modules.

    Args:
        use_input_norm: when true, inputs are normalised with the stored
            ``mean``/``std`` buffers before being passed to the network.
        device: device the ``mean``/``std`` tensors are created on.
    """

    def __init__(self, use_input_norm=True, device=torch.device('cpu')):
        super(ResNet101FeatureExtractor, self).__init__()
        model = torchvision.models.resnet101(pretrained=True)

        self.use_input_norm = use_input_norm
        if self.use_input_norm:
            # Per-channel statistics (presumably the usual ImageNet values), for
            # inputs in [0, 1].
            # For inputs in [-1, 1] use mean [0.485-1, 0.456-1, 0.406-1]
            # and std [0.229*2, 0.224*2, 0.225*2].
            mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device)
            std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device)
            self.register_buffer('mean', mean)
            self.register_buffer('std', std)

        kept_layers = list(model.children())[:8]
        self.features = nn.Sequential(*kept_layers)

        # The extractor is never trained: no gradients for its weights.
        for param in self.features.parameters():
            param.requires_grad = False

    def forward(self, x):
        """Return the activations of the truncated network for ``x``."""
        if not self.use_input_norm:
            return self.features(x)
        normalised = (x - self.mean) / self.std
        return self.features(normalised)
|
||||
|
||||
|
||||
class MINCNet(nn.Module):
    """VGG16-shaped stack of 3x3 convolutions in five stages.

    Stages 1-4 end with a 2x2 ceil-mode max-pool; stage 5 has none, and its
    last convolution (``conv53``) is returned without a ReLU.
    Layers are named ``conv<stage><index>`` / ``maxpool<stage>``.
    """

    def __init__(self):
        super(MINCNet, self).__init__()
        self.ReLU = nn.ReLU(True)

        # (stage number, channel widths through the stage's convolutions)
        stage_plan = (
            (1, (3, 64, 64)),
            (2, (64, 128, 128)),
            (3, (128, 256, 256, 256)),
            (4, (256, 512, 512, 512)),
            (5, (512, 512, 512, 512)),
        )
        for stage, widths in stage_plan:
            for index, (c_in, c_out) in enumerate(zip(widths, widths[1:]), start=1):
                setattr(self, 'conv%d%d' % (stage, index), nn.Conv2d(c_in, c_out, 3, 1, 1))
            if stage < 5:
                setattr(self, 'maxpool%d' % stage,
                        nn.MaxPool2d(2, stride=2, padding=0, ceil_mode=True))

    def forward(self, x):
        pooled_stages = (
            ((self.conv11, self.conv12), self.maxpool1),
            ((self.conv21, self.conv22), self.maxpool2),
            ((self.conv31, self.conv32, self.conv33), self.maxpool3),
            ((self.conv41, self.conv42, self.conv43), self.maxpool4),
        )
        out = x
        for convs, pool in pooled_stages:
            for conv in convs:
                out = self.ReLU(conv(out))
            out = pool(out)

        out = self.ReLU(self.conv51(out))
        out = self.ReLU(self.conv52(out))
        # The final convolution is left un-activated.
        return self.conv53(out)
|
||||
|
||||
|
||||
class MINCFeatureExtractor(nn.Module):
    """Frozen ``MINCNet`` loaded from a fixed checkpoint path.

    The constructor arguments are accepted but not used by this class; the
    weights always come from
    ``../experiments/pretrained_models/VGG16minc_53.pth``.
    """

    def __init__(self, feature_layer=34, use_bn=False, use_input_norm=True,
                 device=torch.device('cpu')):
        super(MINCFeatureExtractor, self).__init__()

        self.features = MINCNet()
        checkpoint = torch.load('../experiments/pretrained_models/VGG16minc_53.pth')
        self.features.load_state_dict(checkpoint, strict=True)
        self.features.eval()

        # The extractor is never trained: no gradients for its weights.
        for param in self.features.parameters():
            param.requires_grad = False

    def forward(self, x):
        """Return the ``MINCNet`` activations for ``x``."""
        return self.features(x)
|
|
@ -1,29 +0,0 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
# Define GAN loss: [vanilla | lsgan]
|
||||
class GANLoss(nn.Module):
    """Adversarial loss against a constant real/fake target.

    Args:
        gan_type: ``'vanilla'`` (``BCEWithLogitsLoss``) or ``'lsgan'``
            (``MSELoss``); matched case-insensitively.
        real_label_val: target value used for "real".
        fake_label_val: target value used for "fake".

    Raises:
        NotImplementedError: for any other ``gan_type``.
    """

    def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0):
        super(GANLoss, self).__init__()
        self.gan_type = gan_type.lower()
        self.real_label_val = real_label_val
        self.fake_label_val = fake_label_val

        if self.gan_type not in ('vanilla', 'lsgan'):
            raise NotImplementedError('GAN type [{:s}] is not found'.format(self.gan_type))
        if self.gan_type == 'vanilla':
            self.loss = nn.BCEWithLogitsLoss()
        else:
            self.loss = nn.MSELoss()

    def get_target_label(self, input, target_is_real):
        """Return a tensor shaped like ``input`` filled with the chosen label."""
        label_value = self.real_label_val if target_is_real else self.fake_label_val
        return torch.empty_like(input).fill_(label_value)

    def forward(self, input, target_is_real):
        """Return the loss of ``input`` against the real or fake target."""
        return self.loss(input, self.get_target_label(input, target_is_real))
|
|
@ -1,94 +0,0 @@
|
|||
import random
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
def _get_random_crop_indices(crop_region, crop_size):
    '''
    Pick a random crop window of a given size inside a region.

    crop_region: (start_y, end_y, start_x, end_x)
    crop_size: (y, x)

    Returns (start_y, end_y, start_x, end_x) of the chosen window, which lies
    entirely inside crop_region. Every valid start position is eligible,
    including the last one (region end minus crop size).

    Raises AssertionError when the crop does not fit in the region.
    '''
    region_size = (crop_region[1] - crop_region[0], crop_region[3] - crop_region[2])
    assert region_size[0] >= crop_size[0] and region_size[1] >= crop_size[1], \
        'crop size {} does not fit in region of size {}'.format(crop_size, region_size)

    # When the crop fills the region along an axis there is only one choice;
    # skip the RNG so no randomness is consumed in that case.
    if region_size[0] == crop_size[0]:
        start_y = crop_region[0]
    else:
        # randint is inclusive on both ends, so the window touching the far
        # edge of the region can be selected as well.
        start_y = random.randint(crop_region[0], crop_region[1] - crop_size[0])

    if region_size[1] == crop_size[1]:
        start_x = crop_region[2]
    else:
        start_x = random.randint(crop_region[2], crop_region[3] - crop_size[1])

    return start_y, start_y + crop_size[0], start_x, start_x + crop_size[1]
|
||||
|
||||
def _get_adaptive_crop_indices(crop_region, crop_size, num_candidate, dist_map, min_diff=False):
    '''
    Draw num_candidate random crops and keep the one with the extreme score.

    The score of a crop is the sum of dist_map over the crop window
    (dist_map is indexed as [y, x]). With min_diff=False the highest-scoring
    candidate is returned, with min_diff=True the lowest-scoring one. Ties go
    to the candidate drawn first.

    crop_region: (start_y, end_y, start_x, end_x)
    crop_size: (y, x)
    '''
    candidates = [_get_random_crop_indices(crop_region, crop_size) for _ in range(num_candidate)]
    max_choice = candidates[0]
    min_choice = candidates[0]
    # Start from -inf/+inf so the true extremes are found even when every
    # score is negative; float('inf') also avoids the removed np.infty alias.
    max_dist = float('-inf')
    min_dist = float('inf')
    with torch.no_grad():
        for c in candidates:
            start_y, end_y, start_x, end_x = c
            dist = torch.sum(dist_map[start_y: end_y, start_x: end_x])
            if dist > max_dist:
                max_dist = dist
                max_choice = c
            if dist < min_dist:
                min_dist = dist
                min_choice = c
    return min_choice if min_diff else max_choice
|
||||
|
||||
def get_split_list(divisor, dividend):
    '''
    Split `dividend` into `divisor` integer parts.

    The first divisor - 1 parts are dividend // divisor each; the last part
    takes whatever is left, so the parts always add up to dividend.
    '''
    share = dividend // divisor
    leading = [share] * (divisor - 1)
    return leading + [dividend - share * (divisor - 1)]
|
||||
|
||||
def random_sampler(pic_size, crop_dict):
    '''
    Sample random square crops anywhere in an image.

    pic_size: (height, width) of the image.
    crop_dict: maps a crop side length (anything int() accepts) to the number
        of crops to draw at that size.

    Returns a dict with the same keys, each mapped to a list of
    (start_y, end_y, start_x, end_x) windows.
    '''
    whole_picture = (0, pic_size[0], 0, pic_size[1])
    return {
        side: [_get_random_crop_indices(whole_picture, (int(side), int(side)))
               for _ in range(count)]
        for side, count in crop_dict.items()
    }
|
||||
|
||||
def region_sampler(crop_region, crop_dict):
    '''
    Sample random square crops inside a given region.

    crop_region: (start_y, end_y, start_x, end_x)
    crop_dict: maps a crop side length (anything int() accepts) to the number
        of crops to draw at that size.

    Returns a dict with the same keys, each mapped to a list of
    (start_y, end_y, start_x, end_x) windows.
    '''
    return {
        side: [_get_random_crop_indices(crop_region, (int(side), int(side)))
               for _ in range(count)]
        for side, count in crop_dict.items()
    }
|
||||
|
||||
def adaptive_sampler(pic_size, crop_dict, num_candidate_dict, dist_map, min_diff=False):
    '''
    Sample square crops over the whole image, each chosen among candidates.

    pic_size: (height, width) of the image.
    crop_dict: maps a crop side length to the number of crops to draw.
    num_candidate_dict: maps the same keys to the number of random candidates
        tried for every crop.
    dist_map, min_diff: passed through to _get_adaptive_crop_indices.

    Returns a dict with the keys of crop_dict, each mapped to a list of
    (start_y, end_y, start_x, end_x) windows.
    '''
    whole_picture = (0, pic_size[0], 0, pic_size[1])
    return {
        side: [_get_adaptive_crop_indices(whole_picture, (int(side), int(side)),
                                          num_candidate_dict[side], dist_map, min_diff)
               for _ in range(count)]
        for side, count in crop_dict.items()
    }
|
||||
|
||||
# TODO more flexible
|
||||
def pyramid_sampler(pic_size, crop_dict):
    '''
    Sample nested crops: each smaller size is drawn inside crops of the next
    larger size.

    pic_size: (height, width) of the image.
    crop_dict: maps a crop side length to the number of crops at that size.

    The largest size is sampled over the whole image. For every following
    (smaller) size, its crop count is divided over the crops of the previous
    size with get_split_list, and each share is sampled inside that parent
    crop. Returns a dict from size key to a list of
    (start_y, end_y, start_x, end_x) windows.
    '''
    # Largest crop size first.
    ordered = sorted(crop_dict.keys(), key=lambda x: int(x), reverse=True)

    top = ordered[0]
    top_size = (int(top), int(top))
    whole_picture = (0, pic_size[0], 0, pic_size[1])
    crop_res_dict = {
        top: [_get_random_crop_indices(whole_picture, top_size) for _ in range(crop_dict[top])]
    }

    for level in range(1, len(ordered)):
        parent_key = ordered[level - 1]
        child_key = ordered[level]
        crop_res_dict[child_key] = []
        child_size = (int(child_key), int(child_key))

        # How many child crops each parent crop receives.
        shares = get_split_list(crop_dict[parent_key], crop_dict[child_key])
        for parent_idx, share in enumerate(shares):
            parent_window = crop_res_dict[parent_key][parent_idx]
            for _ in range(share):
                crop_res_dict[child_key].append(
                    _get_random_crop_indices(parent_window, child_size))

    return crop_res_dict
|
||||
|
|
@ -1,161 +0,0 @@
|
|||
import functools
|
||||
import logging
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import init
|
||||
|
||||
import models.SPSR_modules.architecture as arch
|
||||
logger = logging.getLogger('base')
|
||||
####################
|
||||
# initialize
|
||||
####################
|
||||
|
||||
|
||||
def weights_init_normal(m, std=0.02):
    """Initialise module ``m`` in place from a normal distribution.

    Dispatch is on the class name: names containing ``Conv`` or ``Linear`` get
    weights drawn from N(0, std) and a zeroed bias (if any); names containing
    ``BatchNorm2d`` get weights drawn from N(1, std) and a zero bias. Other
    modules are left untouched.
    """
    classname = m.__class__.__name__
    if 'Conv' in classname or 'Linear' in classname:
        init.normal_(m.weight.data, 0.0, std)
        if m.bias is not None:
            m.bias.data.zero_()
        return
    if 'BatchNorm2d' in classname:
        # BN also uses norm, centred on 1 instead of 0.
        init.normal_(m.weight.data, 1.0, std)
        init.constant_(m.bias.data, 0.0)
|
||||
|
||||
|
||||
def weights_init_kaiming(m, scale=1):
    """Apply Kaiming-normal initialisation (fan-in) to a single module in place.

    Meant to be handed to ``nn.Module.apply``. Conv and Linear modules (matched
    by class-name substring) get Kaiming-normal weights multiplied by ``scale``
    and a zeroed bias. Affine BatchNorm2d modules get weight 1 and bias 0.
    Everything else is left untouched.

    Args:
        m: the module to initialise.
        scale: factor applied to the freshly drawn Conv/Linear weights.
    """
    kind = m.__class__.__name__
    if 'Conv' in kind or 'Linear' in kind:
        init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
        m.weight.data.mul_(scale)
        if m.bias is not None:
            m.bias.data.zero_()
    elif 'BatchNorm2d' in kind and m.affine != False:  # noqa: E712 (comparison kept as in the original)
        # Non-affine batch norm has no weight or bias to initialise.
        init.constant_(m.weight.data, 1.0)
        init.constant_(m.bias.data, 0.0)
|
||||
|
||||
|
||||
def weights_init_orthogonal(m):
    """Apply orthogonal initialisation (gain 1) to a single module in place.

    Meant to be handed to ``nn.Module.apply``. Conv and Linear modules (matched
    by class-name substring) get orthogonal weights and a zeroed bias.
    BatchNorm2d modules get weight 1 and bias 0. Everything else is left
    untouched.

    Args:
        m: the module to initialise.
    """
    classname = m.__class__.__name__
    if 'Conv' in classname or 'Linear' in classname:
        init.orthogonal_(m.weight.data, gain=1)
        if m.bias is not None:
            m.bias.data.zero_()
    elif 'BatchNorm2d' in classname:
        # BatchNorm2d(affine=False) has weight and bias set to None; skip them
        # instead of failing on ``None.data``.
        if m.weight is not None:
            init.constant_(m.weight.data, 1.0)
        if m.bias is not None:
            init.constant_(m.bias.data, 0.0)
|
||||
|
||||
|
||||
def init_weights(net, init_type='kaiming', scale=1, std=0.02):
    """Initialise every submodule of ``net`` with the chosen scheme.

    Args:
        net: object exposing ``apply(fn)`` (e.g. an ``nn.Module``).
        init_type: one of ``'normal'``, ``'kaiming'`` or ``'orthogonal'``.
        scale: weight multiplier, only used by ``'kaiming'``.
        std: standard deviation, only used by ``'normal'``.

    Raises:
        NotImplementedError: if ``init_type`` is a string naming no known scheme.
    """
    # Pick the per-module initialiser first, then apply it once.
    if init_type == 'normal':
        initializer = functools.partial(weights_init_normal, std=std)
    elif init_type == 'kaiming':
        initializer = functools.partial(weights_init_kaiming, scale=scale)
    elif init_type == 'orthogonal':
        initializer = weights_init_orthogonal
    else:
        raise NotImplementedError('initialization method [{:s}] not implemented'.format(init_type))
    net.apply(initializer)
|
||||
|
||||
|
||||
####################
|
||||
# define network
|
||||
####################
|
||||
|
||||
|
||||
# Generator
|
||||
def define_G(opt, device=None):
    """Build the generator described by ``opt['network_G']``.

    Only ``which_model_G == 'spsr_net'`` is supported. When ``opt['is_train']``
    is truthy the new network is Kaiming-initialised with scale 0.1.

    Args:
        opt: options mapping with ``'network_G'`` and ``'is_train'`` entries.
        device: accepted for interface compatibility; not used here.

    Returns:
        The constructed ``arch.SPSRNet``.

    Raises:
        NotImplementedError: if ``which_model_G`` is a string naming any other
            model.
    """
    opt_net = opt['network_G']
    which_model = opt_net['which_model_G']

    if which_model != 'spsr_net':
        raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model))

    netG = arch.SPSRNet(
        in_nc=opt_net['in_nc'],
        out_nc=opt_net['out_nc'],
        nf=opt_net['nf'],
        nb=opt_net['nb'],
        gc=opt_net['gc'],
        upscale=opt_net['scale'],
        norm_type=opt_net['norm_type'],
        act_type='leakyrelu',
        mode=opt_net['mode'],
        upsample_mode='upconv',
    )

    if opt['is_train']:
        init_weights(netG, init_type='kaiming', scale=0.1)

    return netG
|
||||
|
||||
|
||||
|
||||
# Discriminator
|
||||
def define_D(opt):
    """Build and Kaiming-initialise the discriminator named in ``opt['network_D']``.

    Supported ``which_model_D`` values: ``discriminator_vgg_128``,
    ``discriminator_vgg_96``, ``discriminator_vgg_192`` (all configured from
    ``in_nc``, ``nf``, ``norm_type``, ``mode`` and ``act_type``) and
    ``discriminator_vgg_128_SN`` (built without arguments).

    Args:
        opt: options mapping with a ``'network_D'`` entry.

    Returns:
        The constructed discriminator, initialised with ``scale=1``.

    Raises:
        NotImplementedError: if ``which_model_D`` is a string naming any other
            model.
    """
    opt_net = opt['network_D']
    which_model = opt_net['which_model_D']

    # Config name -> class name in ``arch`` for the variants sharing one signature.
    vgg_variants = {
        'discriminator_vgg_128': 'Discriminator_VGG_128',
        'discriminator_vgg_96': 'Discriminator_VGG_96',
        'discriminator_vgg_192': 'Discriminator_VGG_192',
    }

    if which_model in vgg_variants:
        discriminator_cls = getattr(arch, vgg_variants[which_model])
        netD = discriminator_cls(in_nc=opt_net['in_nc'], base_nf=opt_net['nf'],
                                 norm_type=opt_net['norm_type'], mode=opt_net['mode'],
                                 act_type=opt_net['act_type'])
    elif which_model == 'discriminator_vgg_128_SN':
        netD = arch.Discriminator_VGG_128_SN()
    else:
        raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model))

    init_weights(netD, init_type='kaiming', scale=1)
    return netD
|
||||
|
||||
def define_D_grad(opt):
    """Build and Kaiming-initialise the gradient-branch discriminator.

    Reads the same ``opt['network_D']`` settings and supports the same
    ``which_model_D`` values as ``define_D``.

    Args:
        opt: options mapping with a ``'network_D'`` entry.

    Returns:
        The constructed discriminator, initialised with ``scale=1``.

    Raises:
        NotImplementedError: if ``which_model_D`` is a string naming an
            unsupported model.
    """
    opt_net = opt['network_D']
    which_model = opt_net['which_model_D']

    def build_vgg(discriminator_cls):
        # Shared constructor call for the configurable VGG-style discriminators.
        return discriminator_cls(in_nc=opt_net['in_nc'], base_nf=opt_net['nf'],
                                 norm_type=opt_net['norm_type'], mode=opt_net['mode'],
                                 act_type=opt_net['act_type'])

    if which_model == 'discriminator_vgg_128':
        netD = build_vgg(arch.Discriminator_VGG_128)
    elif which_model == 'discriminator_vgg_96':
        netD = build_vgg(arch.Discriminator_VGG_96)
    elif which_model == 'discriminator_vgg_192':
        netD = build_vgg(arch.Discriminator_VGG_192)
    elif which_model == 'discriminator_vgg_128_SN':
        netD = arch.Discriminator_VGG_128_SN()
    else:
        raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model))

    init_weights(netD, init_type='kaiming', scale=1)

    return netD
|
||||
|
||||
|
||||
def define_F(opt, use_bn=False):
    """Build the VGG feature extractor, switched to eval mode.

    Args:
        opt: accepted for interface compatibility; not read here.
        use_bn: selects the batch-norm VGG variant and the matching layer index.

    Returns:
        An ``arch.VGGFeatureExtractor`` created for the ``'cuda'`` device with
        input normalisation enabled.
    """
    # pytorch pretrained VGG19-54, before ReLU.
    feature_layer = 49 if use_bn else 34
    netF = arch.VGGFeatureExtractor(
        feature_layer=feature_layer,
        use_bn=use_bn,
        use_input_norm=True,
        device=torch.device('cuda'),
    )
    netF.eval()
    return netF
|
|
@ -9,6 +9,7 @@ from models.base_model import BaseModel
|
|||
from models.loss import GANLoss, FDPLLoss
|
||||
from apex import amp
|
||||
from data.weight_scheduler import get_scheduler_for_opt
|
||||
from .archs.SPSR_arch import ImageGradient, ImageGradientNoPadding
|
||||
import torch.nn.functional as F
|
||||
import glob
|
||||
import random
|
||||
|
@ -27,11 +28,18 @@ class SRGANModel(BaseModel):
|
|||
else:
|
||||
self.rank = -1 # non dist training
|
||||
train_opt = opt['train']
|
||||
self.spsr_enabled = 'spsr' in opt['model']
|
||||
|
||||
# Only pixgan and gan are currently supported in spsr_mode
|
||||
if self.spsr_enabled:
|
||||
assert train_opt['gan_type'] == 'pixgan' or train_opt['gan_type'] == 'gan'
|
||||
|
||||
# define networks and load pretrained models
|
||||
self.netG = networks.define_G(opt).to(self.device)
|
||||
if self.is_train:
|
||||
self.netD = networks.define_D(opt).to(self.device)
|
||||
if self.spsr_enabled:
|
||||
self.netD_grad = networks.define_D(opt).to(self.device) # D_grad
|
||||
|
||||
if 'network_C' in opt.keys():
|
||||
self.netC = networks.define_G(opt, net_key='network_C').to(self.device)
|
||||
|
@ -73,6 +81,33 @@ class SRGANModel(BaseModel):
|
|||
else:
|
||||
self.fdpl_enabled = False
|
||||
|
||||
if self.spsr_enabled:
|
||||
spsr_opt = train_opt['spsr']
|
||||
self.branch_pretrain = spsr_opt['branch_pretrain'] if spsr_opt['branch_pretrain'] else 0
|
||||
self.branch_init_iters = spsr_opt['branch_init_iters'] if spsr_opt['branch_init_iters'] else 1
|
||||
if spsr_opt['gradient_pixel_weight'] > 0:
|
||||
self.cri_pix_grad = nn.MSELoss().to(self.device)
|
||||
self.l_pix_grad_w = spsr_opt['gradient_pixel_weight']
|
||||
else:
|
||||
self.cri_pix_grad = None
|
||||
if spsr_opt['gradient_gan_weight'] > 0:
|
||||
self.cri_grad_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device)
|
||||
self.l_gan_grad_w = spsr_opt['gradient_gan_weight']
|
||||
else:
|
||||
self.cri_grad_gan = None
|
||||
if spsr_opt['pixel_branch_weight'] > 0:
|
||||
l_pix_type = spsr_opt['pixel_branch_criterion']
|
||||
if l_pix_type == 'l1':
|
||||
self.cri_pix_branch = nn.L1Loss().to(self.device)
|
||||
elif l_pix_type == 'l2':
|
||||
self.cri_pix_branch = nn.MSELoss().to(self.device)
|
||||
else:
|
||||
raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_pix_type))
|
||||
self.l_pix_branch_w = spsr_opt['pixel_branch_weight']
|
||||
else:
|
||||
logger.info('Remove G_grad pixel loss.')
|
||||
self.cri_pix_branch = None
|
||||
|
||||
# G feature loss
|
||||
if train_opt['feature_weight'] and train_opt['feature_weight'] > 0:
|
||||
# For backwards compatibility, use a scheduler definition instead. Remove this at some point.
|
||||
|
@ -139,7 +174,7 @@ class SRGANModel(BaseModel):
|
|||
self.corruptor_usage_prob = train_opt['corruptor_usage_probability'] if train_opt['corruptor_usage_probability'] else .5
|
||||
|
||||
# optimizers
|
||||
# G
|
||||
# G optimizer
|
||||
wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
|
||||
optim_params = []
|
||||
if train_opt['lr_scheme'] == 'ProgressiveMultiStepLR':
|
||||
|
@ -155,6 +190,7 @@ class SRGANModel(BaseModel):
|
|||
weight_decay=wd_G,
|
||||
betas=(train_opt['beta1_G'], train_opt['beta2_G']))
|
||||
self.optimizers.append(self.optimizer_G)
|
||||
# D optimizer
|
||||
optim_params = []
|
||||
for k, v in self.netD.named_parameters(): # can optimize for a part of the model
|
||||
if v.requires_grad:
|
||||
|
@ -162,16 +198,40 @@ class SRGANModel(BaseModel):
|
|||
else:
|
||||
if self.rank <= 0:
|
||||
logger.warning('Params [{:s}] will not optimize.'.format(k))
|
||||
# D
|
||||
wd_D = train_opt['weight_decay_D'] if train_opt['weight_decay_D'] else 0
|
||||
self.optimizer_D = torch.optim.Adam(optim_params, lr=train_opt['lr_D'],
|
||||
weight_decay=wd_D,
|
||||
betas=(train_opt['beta1_D'], train_opt['beta2_D']))
|
||||
self.optimizers.append(self.optimizer_D)
|
||||
|
||||
# AMP
|
||||
[self.netG, self.netD], [self.optimizer_G, self.optimizer_D] = \
|
||||
amp.initialize([self.netG, self.netD], [self.optimizer_G, self.optimizer_D], opt_level=self.amp_level, num_losses=3)
|
||||
if self.spsr_enabled:
|
||||
# D_grad optimizer
|
||||
optim_params = []
|
||||
for k, v in self.netD_grad.named_parameters(): # can optimize for a part of the model
|
||||
if v.requires_grad:
|
||||
optim_params.append(v)
|
||||
else:
|
||||
if self.rank <= 0:
|
||||
logger.warning('Params [{:s}] will not optimize.'.format(k))
|
||||
# D
|
||||
wd_D = train_opt['weight_decay_D'] if train_opt['weight_decay_D'] else 0
|
||||
self.optimizer_D_grad = torch.optim.Adam(optim_params, lr=train_opt['lr_D'],
|
||||
weight_decay=wd_D,
|
||||
betas=(train_opt['beta1_D'], train_opt['beta2_D']))
|
||||
self.optimizers.append(self.optimizer_D_grad)
|
||||
|
||||
if self.spsr_enabled:
|
||||
self.get_grad = ImageGradient().to(self.device)
|
||||
self.get_grad_nopadding = ImageGradientNoPadding().to(self.device)
|
||||
[self.netG, self.netD, self.netD_grad, self.get_grad, self.get_grad_nopadding], \
|
||||
[self.optimizer_G, self.optimizer_D, self.optimizer_D_grad] = \
|
||||
amp.initialize([self.netG, self.netD, self.netD_grad, self.get_grad, self.get_grad_nopadding],
|
||||
[self.optimizer_G, self.optimizer_D, self.optimizer_D_grad],
|
||||
opt_level=self.amp_level, num_losses=3)
|
||||
else:
|
||||
# AMP
|
||||
[self.netG, self.netD], [self.optimizer_G, self.optimizer_D] = \
|
||||
amp.initialize([self.netG, self.netD], [self.optimizer_G, self.optimizer_D], opt_level=self.amp_level, num_losses=3)
|
||||
|
||||
# DataParallel
|
||||
if opt['dist']:
|
||||
|
@ -188,6 +248,8 @@ class SRGANModel(BaseModel):
|
|||
self.netD = DataParallel(self.netD)
|
||||
self.netG.train()
|
||||
self.netD.train()
|
||||
if self.spsr_enabled:
|
||||
self.netD_grad.train()
|
||||
|
||||
# schedulers
|
||||
if train_opt['lr_scheme'] == 'MultiStepLR':
|
||||
|
@ -208,6 +270,10 @@ class SRGANModel(BaseModel):
|
|||
self.schedulers.append(lr_scheduler.ProgressiveMultiStepLR(self.optimizer_D, train_opt['disc_lr_steps'],
|
||||
[0],
|
||||
train_opt['lr_gamma']))
|
||||
if self.spsr_enabled:
|
||||
self.schedulers.append(lr_scheduler.ProgressiveMultiStepLR(self.optimizer_D_grad, train_opt['disc_lr_steps'],
|
||||
[0],
|
||||
train_opt['lr_gamma']))
|
||||
elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
|
||||
for optimizer in self.optimizers:
|
||||
self.schedulers.append(
|
||||
|
@ -284,18 +350,22 @@ class SRGANModel(BaseModel):
|
|||
# G
|
||||
for p in self.netD.parameters():
|
||||
p.requires_grad = False
|
||||
|
||||
if step >= self.D_init_iters:
|
||||
self.optimizer_G.zero_grad()
|
||||
if self.spsr_enabled:
|
||||
for p in self.netD_grad.parameters():
|
||||
p.requires_grad = False
|
||||
|
||||
self.swapout_D(step)
|
||||
self.swapout_G(step)
|
||||
|
||||
# Turning off G-grad is required to enable mega-batching and D_update_ratio to work together for some reason.
|
||||
if step % self.D_update_ratio == 0 and step >= self.D_init_iters:
|
||||
for p in self.netG.parameters():
|
||||
if p.dtype != torch.int64 and p.dtype != torch.bool:
|
||||
p.requires_grad = True
|
||||
if self.spsr_enabled and self.branch_pretrain and step < self.branch_init_iters:
|
||||
for k, v in self.netG.named_parameters():
|
||||
v.requires_grad = '_branch_pretrain' in k
|
||||
else:
|
||||
for p in self.netG.parameters():
|
||||
if p.dtype != torch.int64 and p.dtype != torch.bool:
|
||||
p.requires_grad = True
|
||||
else:
|
||||
for p in self.netG.parameters():
|
||||
p.requires_grad = False
|
||||
|
@ -310,17 +380,32 @@ class SRGANModel(BaseModel):
|
|||
print("Misc setup %f" % (time() - _t,))
|
||||
_t = time()
|
||||
|
||||
if step >= self.D_init_iters:
|
||||
self.optimizer_G.zero_grad()
|
||||
self.fake_GenOut = []
|
||||
self.fea_GenOut = []
|
||||
self.fake_H = []
|
||||
self.spsr_grad_GenOut = []
|
||||
var_ref_skips = []
|
||||
for var_L, var_LGAN, var_H, var_ref, pix in zip(self.var_L, self.gan_img, self.var_H, self.var_ref, self.pix):
|
||||
if random.random() > self.gan_lq_img_use_prob:
|
||||
fea_GenOut, fake_GenOut = self.netG(var_L)
|
||||
if self.spsr_enabled:
|
||||
# SPSR models have outputs from three different branches.
|
||||
fake_H_branch, fake_GenOut, grad_LR = self.netG(var_L)
|
||||
fea_GenOut = fake_GenOut
|
||||
using_gan_img = False
|
||||
# Get image gradients for later use.
|
||||
fake_H_grad = self.get_grad(fake_GenOut)
|
||||
var_H_grad = self.get_grad(var_H)
|
||||
var_ref_grad = self.get_grad(var_ref)
|
||||
var_H_grad_nopadding = self.get_grad_nopadding(var_H)
|
||||
self.spsr_grad_GenOut.append(grad_LR)
|
||||
else:
|
||||
fea_GenOut, fake_GenOut = self.netG(var_LGAN)
|
||||
using_gan_img = True
|
||||
if random.random() > self.gan_lq_img_use_prob:
|
||||
fea_GenOut, fake_GenOut = self.netG(var_L)
|
||||
using_gan_img = False
|
||||
else:
|
||||
fea_GenOut, fake_GenOut = self.netG(var_LGAN)
|
||||
using_gan_img = True
|
||||
|
||||
if _profile:
|
||||
print("Gen forward %f" % (time() - _t,))
|
||||
|
@ -339,6 +424,13 @@ class SRGANModel(BaseModel):
|
|||
l_g_pix = self.l_pix_w * self.cri_pix(fea_GenOut, pix)
|
||||
l_g_pix_log = l_g_pix / self.l_pix_w
|
||||
l_g_total += l_g_pix
|
||||
if self.spsr_enabled and self.cri_pix_grad: # gradient pixel loss
|
||||
l_g_pix_grad = self.l_pix_grad_w * self.cri_pix_grad(fake_H_grad, var_H_grad)
|
||||
l_g_total += l_g_pix_grad
|
||||
if self.spsr_enabled and self.cri_pix_branch: # branch pixel loss
|
||||
l_g_pix_grad_branch = self.l_pix_branch_w * self.cri_pix_branch(fake_H_branch,
|
||||
var_H_grad_nopadding)
|
||||
l_g_total += l_g_pix_grad_branch
|
||||
if self.fdpl_enabled and not using_gan_img:
|
||||
l_g_fdpl = self.cri_fdpl(fea_GenOut, pix)
|
||||
l_g_total += l_g_fdpl * self.fdpl_weight
|
||||
|
@ -370,6 +462,7 @@ class SRGANModel(BaseModel):
|
|||
l_g_fix_disc = l_g_fix_disc + weight * self.cri_fea(fake_fea, real_fea)
|
||||
l_g_total += l_g_fix_disc
|
||||
|
||||
|
||||
if self.l_gan_w > 0:
|
||||
if self.opt['train']['gan_type'] == 'gan' or 'pixgan' in self.opt['train']['gan_type']:
|
||||
pred_g_fake = self.netD(fake_GenOut)
|
||||
|
@ -383,6 +476,14 @@ class SRGANModel(BaseModel):
|
|||
l_g_gan_log = l_g_gan / self.l_gan_w
|
||||
l_g_total += l_g_gan
|
||||
|
||||
if self.spsr_enabled and self.cri_grad_gan: # grad G gan + cls loss
|
||||
pred_g_fake_grad = self.netD_grad(fake_H_grad)
|
||||
pred_d_real_grad = self.netD_grad(var_ref_grad).detach()
|
||||
|
||||
l_g_gan_grad = self.l_gan_grad_w * (self.cri_grad_gan(pred_d_real_grad - torch.mean(pred_g_fake_grad), False) +
|
||||
self.cri_grad_gan(pred_g_fake_grad - torch.mean(pred_d_real_grad), True)) /2
|
||||
l_g_total += l_g_gan_grad
|
||||
|
||||
# Scale the loss down by the batch factor.
|
||||
l_g_total_log = l_g_total
|
||||
l_g_total = l_g_total / self.mega_batch_factor
|
||||
|
@ -418,8 +519,10 @@ class SRGANModel(BaseModel):
|
|||
gen_input = var_LGAN
|
||||
# Re-compute generator outputs (post-update).
|
||||
with torch.no_grad():
|
||||
_, fake_H = self.netG(gen_input)
|
||||
# The following line detaches all generator outputs that are not None.
|
||||
if self.spsr_enabled:
|
||||
_, fake_H, _ = self.netG(gen_input)
|
||||
else:
|
||||
_, fake_H = self.netG(gen_input)
|
||||
fake_H = fake_H.detach()
|
||||
|
||||
if _profile:
|
||||
|
@ -546,11 +649,36 @@ class SRGANModel(BaseModel):
|
|||
self.fake_H.append(fake_H.detach())
|
||||
self.optimizer_D.step()
|
||||
|
||||
|
||||
if _profile:
|
||||
print("Disc step %f" % (time() - _t,))
|
||||
_t = time()
|
||||
|
||||
# D_grad.
|
||||
if self.spsr_enabled and self.cri_grad_gan and step >= self.G_warmup:
|
||||
for p in self.netD_grad.parameters():
|
||||
p.requires_grad = True
|
||||
self.optimizer_D_grad.zero_grad()
|
||||
|
||||
for var_ref, fake_H in zip(self.var_ref, self.fake_H):
|
||||
fake_H_grad = self.get_grad(fake_H)
|
||||
var_ref_grad = self.get_grad(var_ref)
|
||||
pred_d_real_grad = self.netD_grad(var_ref_grad)
|
||||
pred_d_fake_grad = self.netD_grad(fake_H_grad.detach()) # detach to avoid BP to G
|
||||
if self.opt['train']['gan_type'] == 'gan':
|
||||
l_d_real_grad = self.cri_grad_gan(pred_d_real_grad - torch.mean(pred_d_fake_grad), True)
|
||||
l_d_fake_grad = self.cri_grad_gan(pred_d_fake_grad - torch.mean(pred_d_real_grad), False)
|
||||
elif self.opt['train']['gan_type'] == 'pixgan':
|
||||
real = torch.ones_like(pred_d_real_grad)
|
||||
fake = torch.zeros_like(pred_d_fake_grad)
|
||||
l_d_real_grad = self.cri_grad_gan(pred_d_real_grad - torch.mean(pred_d_fake_grad), real)
|
||||
l_d_fake_grad = self.cri_grad_gan(pred_d_fake_grad - torch.mean(pred_d_real_grad), fake)
|
||||
l_d_total_grad = (l_d_real_grad + l_d_fake_grad) / 2
|
||||
l_d_total_grad /= self.mega_batch_factor
|
||||
with amp.scale_loss(l_d_total_grad, self.optimizer_D_grad, loss_id=2) as l_d_total_grad_scaled:
|
||||
l_d_total_grad_scaled.backward()
|
||||
self.optimizer_D_grad.step()
|
||||
|
||||
|
||||
# Log sample images from first microbatch.
|
||||
if step % self.img_debug_steps == 0:
|
||||
sample_save_path = os.path.join(self.opt['path']['models'], "..", "temp")
|
||||
|
@ -562,6 +690,8 @@ class SRGANModel(BaseModel):
|
|||
os.makedirs(os.path.join(sample_save_path, "pix"), exist_ok=True)
|
||||
os.makedirs(os.path.join(sample_save_path, "disc"), exist_ok=True)
|
||||
os.makedirs(os.path.join(sample_save_path, "ref"), exist_ok=True)
|
||||
if self.spsr_enabled:
|
||||
os.makedirs(os.path.join(sample_save_path, "gen_grad"), exist_ok=True)
|
||||
|
||||
# fed_LQ is not chunked.
|
||||
for i in range(self.mega_batch_factor):
|
||||
|
@ -570,6 +700,8 @@ class SRGANModel(BaseModel):
|
|||
utils.save_image(self.pix[i].cpu(), os.path.join(sample_save_path, "pix", "%05i_%02i.png" % (step, i)))
|
||||
utils.save_image(self.fake_GenOut[i].cpu(), os.path.join(sample_save_path, "gen", "%05i_%02i.png" % (step, i)))
|
||||
utils.save_image(self.fea_GenOut[i].cpu(), os.path.join(sample_save_path, "gen_fea", "%05i_%02i.png" % (step, i)))
|
||||
if self.spsr_enabled:
|
||||
utils.save_image(self.spsr_grad_GenOut[i].cpu(), os.path.join(sample_save_path, "gen_grad", "%05i_%02i.png" % (step, i)))
|
||||
if self.l_gan_w > 0 and step >= self.G_warmup and 'pixgan' in self.opt['train']['gan_type']:
|
||||
utils.save_image(var_ref_skips[i].cpu(), os.path.join(sample_save_path, "ref", "%05i_%02i.png" % (step, i)))
|
||||
utils.save_image(self.fake_H[i], os.path.join(sample_save_path, "disc_fake", "fake%05i_%02i.png" % (step, i)))
|
||||
|
@ -594,11 +726,19 @@ class SRGANModel(BaseModel):
|
|||
self.add_log_entry('l_d_fea_real', l_d_fea_real.item() * self.mega_batch_factor)
|
||||
self.add_log_entry('l_d_fake_total', l_d_fake.item() * self.mega_batch_factor)
|
||||
self.add_log_entry('l_d_real_total', l_d_real.item() * self.mega_batch_factor)
|
||||
if self.spsr_enabled:
|
||||
if self.cri_pix_branch:
|
||||
self.add_log_entry('l_g_pix_grad_branch', l_g_pix_grad_branch.item())
|
||||
if self.l_gan_w > 0 and step >= self.G_warmup:
|
||||
self.add_log_entry('l_d_real', l_d_real_log.item())
|
||||
self.add_log_entry('l_d_fake', l_d_fake_log.item())
|
||||
self.add_log_entry('D_fake', torch.mean(pred_d_fake.detach()))
|
||||
self.add_log_entry('D_diff', torch.mean(pred_d_fake) - torch.mean(pred_d_real))
|
||||
if self.spsr_enabled:
|
||||
self.add_log_entry('l_d_real_grad', l_d_real_grad.item())
|
||||
self.add_log_entry('l_d_fake_grad', l_d_fake_grad.item())
|
||||
self.add_log_entry('D_fake', torch.mean(pred_d_fake_grad.detach()))
|
||||
self.add_log_entry('D_diff', torch.mean(pred_d_fake_grad) - torch.mean(pred_d_real_grad))
|
||||
|
||||
# Log learning rates.
|
||||
for i, pg in enumerate(self.optimizer_G.param_groups):
|
||||
|
@ -685,7 +825,16 @@ class SRGANModel(BaseModel):
|
|||
def test(self):
|
||||
self.netG.eval()
|
||||
with torch.no_grad():
|
||||
self.fake_GenOut = [self.netG(self.var_L[0])]
|
||||
if self.spsr_enabled:
|
||||
self.fake_H_branch = []
|
||||
self.fake_GenOut = []
|
||||
self.grad_LR = []
|
||||
fake_H_branch, fake_GenOut, grad_LR = self.netG(self.var_L[0])
|
||||
self.fake_H_branch.append(fake_H_branch)
|
||||
self.fake_GenOut.append(fake_GenOut)
|
||||
self.grad_LR.append(grad_LR)
|
||||
else:
|
||||
self.fake_GenOut = [self.netG(self.var_L[0])]
|
||||
self.netG.train()
|
||||
|
||||
# Fetches a summary of the log.
|
||||
|
@ -713,6 +862,9 @@ class SRGANModel(BaseModel):
|
|||
out_dict['rlt'] = gen_batch.detach().float().cpu()
|
||||
if need_GT:
|
||||
out_dict['GT'] = self.var_H[0].detach().float().cpu()
|
||||
if self.spsr_enabled:
|
||||
out_dict['SR_branch'] = self.fake_H_branch[0].float().cpu()
|
||||
out_dict['LR_grad'] = self.grad_LR[0].float().cpu()
|
||||
return out_dict
|
||||
|
||||
def print_network(self):
|
||||
|
@ -762,6 +914,11 @@ class SRGANModel(BaseModel):
|
|||
if self.opt['is_train'] and load_path_D is not None:
|
||||
logger.info('Loading model for D [{:s}] ...'.format(load_path_D))
|
||||
self.load_network(load_path_D, self.netD, self.opt['path']['strict_load'])
|
||||
if self.spsr_enabled:
|
||||
load_path_D_grad = self.opt['path']['pretrain_model_D_grad']
|
||||
if self.opt['is_train'] and load_path_D_grad is not None:
|
||||
logger.info('Loading pretrained model for D_grad [{:s}] ...'.format(load_path_D_grad))
|
||||
self.load_network(load_path_D_grad, self.netD_grad)
|
||||
|
||||
def load_random_corruptor(self):
|
||||
if self.netC is None:
|
||||
|
@ -774,3 +931,4 @@ class SRGANModel(BaseModel):
|
|||
def save(self, iter_step):
    """Checkpoint the generator and discriminator(s) for ``iter_step``.

    Args:
        iter_step: iteration label forwarded to ``save_network``.
    """
    self.save_network(self.netG, 'G', iter_step)
    self.save_network(self.netD, 'D', iter_step)
    # netD_grad is only constructed in SPSR mode; saving it unconditionally
    # raises AttributeError for plain SRGAN / corruptgan runs.
    if self.spsr_enabled:
        self.save_network(self.netD_grad, 'D_grad', iter_step)
|
||||
|
|
|
@ -7,11 +7,11 @@ def create_model(opt):
|
|||
# image restoration
|
||||
if model == 'sr': # PSNR-oriented super resolution
|
||||
from .SR_model import SRModel as M
|
||||
elif model == 'srgan' or model == 'corruptgan': # GAN-based super resolution(SRGAN / ESRGAN), or corruption use same logic
|
||||
elif model == 'srgan' or model == 'corruptgan' or model == 'spsrgan':
|
||||
from .SRGAN_model import SRGANModel as M
|
||||
elif model == 'feat':
|
||||
from .feature_model import FeatureModel as M
|
||||
if model == 'spsr':
|
||||
elif model == 'spsr':
|
||||
from .SPSR_model import SPSRModel as M
|
||||
else:
|
||||
raise NotImplementedError('Model [{:s}] not recognized.'.format(model))
|
||||
|
|
226
codes/models/archs/SPSR_arch.py
Normal file
226
codes/models/archs/SPSR_arch.py
Normal file
|
@ -0,0 +1,226 @@
|
|||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from models.archs import SPSR_util as B
|
||||
from .RRDBNet_arch import RRDB
|
||||
|
||||
|
||||
class ImageGradient(nn.Module):
    """Per-channel gradient magnitude of the first three channels of an image.

    Each channel is filtered with a vertical and a horizontal central-difference
    kernel and the two responses are combined as ``sqrt(v**2 + h**2 + 1e-6)``.
    The 3x3 convolutions use ``padding=2``, so the output is two pixels larger
    than the input in both height and width.

    The kernels are registered as frozen parameters (as in
    ``ImageGradientNoPadding``), so the module can be built without CUDA and
    follows ``.to(device)``.
    """

    def __init__(self):
        super(ImageGradient, self).__init__()
        kernel_v = [[0, -1, 0],
                    [0, 0, 0],
                    [0, 1, 0]]
        kernel_h = [[0, 0, 0],
                    [-1, 0, 1],
                    [0, 0, 0]]
        kernel_h = torch.FloatTensor(kernel_h).unsqueeze(0).unsqueeze(0)
        kernel_v = torch.FloatTensor(kernel_v).unsqueeze(0).unsqueeze(0)
        # No ``.cuda()`` here: calling it on a Parameter yields a plain tensor
        # that the module does not track and that fails on CPU-only machines.
        self.weight_h = nn.Parameter(data=kernel_h, requires_grad=False)
        self.weight_v = nn.Parameter(data=kernel_v, requires_grad=False)

    def forward(self, x):
        """Return the gradient magnitude of channels 0-2 of ``x``.

        Args:
            x: tensor indexed as ``x[:, c]`` for c in 0..2; channels beyond the
                third are ignored.

        Returns:
            Tensor with three channels, concatenated along dim 1.
        """
        # Follow the input's device so callers that never moved this module
        # (the old code always used CUDA kernels) keep working.
        weight_v = self.weight_v.to(x.device)
        weight_h = self.weight_h.to(x.device)

        magnitudes = []
        for channel_idx in range(3):
            channel = x[:, channel_idx].unsqueeze(1)
            grad_v = F.conv2d(channel, weight_v, padding=2)
            grad_h = F.conv2d(channel, weight_h, padding=2)
            # The epsilon keeps sqrt differentiable where the gradient is zero.
            magnitudes.append(torch.sqrt(torch.pow(grad_v, 2) + torch.pow(grad_h, 2) + 1e-6))

        return torch.cat(magnitudes, dim=1)
|
||||
|
||||
|
||||
class ImageGradientNoPadding(nn.Module):
    """Per-channel gradient magnitude that keeps the input's spatial size.

    Every channel is filtered with a vertical and a horizontal
    central-difference 3x3 kernel (``padding=1``) and the responses are
    combined as ``sqrt(v**2 + h**2 + 1e-6)``. The kernels are stored as frozen
    parameters ``weight_h`` and ``weight_v``.
    """

    def __init__(self):
        super().__init__()
        vertical = torch.FloatTensor([[0, -1, 0],
                                      [0, 0, 0],
                                      [0, 1, 0]])
        horizontal = torch.FloatTensor([[0, 0, 0],
                                        [-1, 0, 1],
                                        [0, 0, 0]])
        # Shape (1, 1, 3, 3): one output channel, one input channel.
        self.weight_h = nn.Parameter(data=horizontal[None, None], requires_grad=False)
        self.weight_v = nn.Parameter(data=vertical[None, None], requires_grad=False)

    def forward(self, x):
        """Return the gradient magnitude of every channel of ``x``.

        Args:
            x: tensor whose channels are indexed as ``x[:, i]``.

        Returns:
            Tensor with one gradient-magnitude channel per input channel,
            concatenated along dim 1.
        """
        def magnitude(channel):
            # ``channel`` holds a single input channel with a singleton channel dim.
            grad_v = F.conv2d(channel, self.weight_v, padding=1)
            grad_h = F.conv2d(channel, self.weight_h, padding=1)
            return torch.sqrt(torch.pow(grad_v, 2) + torch.pow(grad_h, 2) + 1e-6)

        per_channel = [magnitude(x[:, idx].unsqueeze(1)) for idx in range(x.shape[1])]
        return torch.cat(per_channel, dim=1)
|
||||
|
||||
|
||||
####################
|
||||
# Generator
|
||||
####################
|
||||
|
||||
class SPSRNet(nn.Module):
    """SPSR generator: an RRDB trunk plus a gradient branch fed by trunk features.

    ``forward`` returns three tensors: the gradient-branch image, the final
    fused image, and the gradient map computed from the input.

    NOTE(review): ``gc`` is accepted but every RRDB is built with ``gc=32``, and
    ``forward`` indexes trunk blocks 0..19, so ``nb`` must be at least 20.
    """

    def __init__(self, in_nc, out_nc, nf, nb, gc=32, upscale=4, norm_type=None, \
            act_type='leakyrelu', mode='CNA', upsample_mode='upconv'):
        super().__init__()

        def plain_conv(c_in, c_out, kernel_size=3, act=None):
            # Convolution block without normalisation.
            return B.conv_block(c_in, c_out, kernel_size=kernel_size, norm_type=None, act_type=act)

        def make_upsampler():
            # For x3 a single block is returned (its children get unpacked by
            # the caller); otherwise one block per factor of two.
            if upscale == 3:
                return upsample_block(nf, nf, 3, act_type=act_type)
            return [upsample_block(nf, nf, act_type=act_type) for _ in range(n_upscale)]

        n_upscale = int(math.log(upscale, 2))
        if upscale == 3:
            n_upscale = 1

        # ---- trunk (submodules are created in the same order as before so
        # parameter registration and random initialisation are unchanged) ----
        fea_conv = plain_conv(in_nc, nf)
        rb_blocks = [RRDB(nf, gc=32) for _ in range(nb)]
        LR_conv = B.conv_block(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None, mode=mode)

        if upsample_mode == 'upconv':
            upsample_block = B.upconv_blcok
        elif upsample_mode == 'pixelshuffle':
            upsample_block = B.pixelshuffle_block
        else:
            raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode))
        upsampler = make_upsampler()

        self.HR_conv0_new = plain_conv(nf, nf, act=act_type)
        self.HR_conv1_new = plain_conv(nf, nf)

        self.model = B.sequential(
            fea_conv,
            B.ShortcutBlock(B.sequential(*rb_blocks, LR_conv)),
            *upsampler,
            self.HR_conv0_new,
        )

        # ---- gradient branch ----
        self.get_g_nopadding = ImageGradientNoPadding()

        self.b_fea_conv = plain_conv(in_nc, nf)

        self.b_concat_1 = plain_conv(2*nf, nf)
        self.b_block_1 = RRDB(nf*2, gc=32)

        self.b_concat_2 = plain_conv(2*nf, nf)
        self.b_block_2 = RRDB(nf*2, gc=32)

        self.b_concat_3 = plain_conv(2*nf, nf)
        self.b_block_3 = RRDB(nf*2, gc=32)

        self.b_concat_4 = plain_conv(2*nf, nf)
        self.b_block_4 = RRDB(nf*2, gc=32)

        self.b_LR_conv = B.conv_block(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None, mode=mode)

        # ``upsample_block`` chosen above is reused for the branch upsampler.
        b_upsampler = make_upsampler()
        b_HR_conv0 = plain_conv(nf, nf, act=act_type)
        b_HR_conv1 = plain_conv(nf, nf)

        self.b_module = B.sequential(*b_upsampler, b_HR_conv0, b_HR_conv1)

        self.conv_w = plain_conv(nf, out_nc, kernel_size=1)

        # Note: "_branch_pretrain" is a special tag used to denote parameters that get pretrained before the rest.
        self._branch_pretrain_concat = plain_conv(nf*2, nf)
        self._branch_pretrain_block = RRDB(nf*2, gc=32)
        self._branch_pretrain_HR_conv0 = plain_conv(nf, nf, act=act_type)
        self._branch_pretrain_HR_conv1 = plain_conv(nf, out_nc)

    def forward(self, x):
        """Run both branches.

        Returns:
            Tuple ``(x_out_branch, x_out, x_grad)``: gradient-branch output,
            fused output, and the gradient map of the input.
        """
        x_grad = self.get_g_nopadding(x)

        # self.model[1] is the ShortcutBlock; it is unpacked here as
        # (tensor, indexable container of trunk blocks) -- presumably the
        # project's modified ShortcutBlock, confirm in SPSR_util.
        trunk, block_list = self.model[1](self.model[0](x))
        shortcut = trunk

        # Walk the first 20 trunk blocks, tapping the features after every 5.
        stage_feats = []
        for stage in range(4):
            for offset in range(5):
                trunk = block_list[5 * stage + offset](trunk)
            stage_feats.append(trunk)

        trunk = block_list[20:](trunk)
        # short cut
        trunk = shortcut + trunk
        trunk = self.model[2:](trunk)
        trunk = self.HR_conv1_new(trunk)

        # Gradient branch: fuse each tapped trunk feature in turn.
        branch_fea = self.b_fea_conv(x_grad)
        fusion_stages = (
            (self.b_block_1, self.b_concat_1),
            (self.b_block_2, self.b_concat_2),
            (self.b_block_3, self.b_concat_3),
            (self.b_block_4, self.b_concat_4),
        )
        branch = branch_fea
        for (block, concat), trunk_feat in zip(fusion_stages, stage_feats):
            branch = torch.cat([branch, trunk_feat], dim=1)
            branch = block(branch)
            branch = concat(branch)

        branch = self.b_LR_conv(branch)
        # short cut
        branch = branch + branch_fea
        x_branch = self.b_module(branch)

        x_out_branch = self.conv_w(x_branch)

        # Fusion head (the "_branch_pretrain" parameters).
        fused = torch.cat([x_branch, trunk], dim=1)
        fused = self._branch_pretrain_block(fused)
        x_out = self._branch_pretrain_concat(fused)
        x_out = self._branch_pretrain_HR_conv0(x_out)
        x_out = self._branch_pretrain_HR_conv1(x_out)

        return x_out_branch, x_out, x_grad
|
||||
|
|
@ -5,8 +5,6 @@ import torch.nn as nn
|
|||
####################
|
||||
# Basic blocks
|
||||
####################
|
||||
|
||||
|
||||
def act(act_type, inplace=True, neg_slope=0.2, n_prelu=1):
|
||||
# helper selecting activation
|
||||
# neg_slope: for leakyrelu and init of prelu
|
||||
|
@ -134,101 +132,6 @@ def conv_block(in_nc, out_nc, kernel_size, stride=1, dilation=1, groups=1, bias=
|
|||
return sequential(n, a, p, c)
|
||||
|
||||
|
||||
####################
|
||||
# Useful blocks
|
||||
####################
|
||||
|
||||
class ResNetBlock(nn.Module):
    """Two-convolution residual block (3-3 style) with EDSR-style residual scaling.

    Reference: Enhanced Deep Residual Networks for Single Image
    Super-Resolution (CVPRW 17).

    The residual path is ``conv_block(in_nc -> mid_nc)`` followed by
    ``conv_block(mid_nc -> out_nc)``. For ``mode == 'CNA'`` the second conv has
    no activation; for ``mode == 'CNAC'`` it has neither activation nor norm.
    The skip path is the untouched input (there is no projection layer), so the
    residual output must be shape-compatible with ``x`` for the final add.
    """

    def __init__(self, in_nc, mid_nc, out_nc, kernel_size=3, stride=1, dilation=1, groups=1,
                 bias=True, pad_type='zero', norm_type=None, act_type='relu', mode='CNA',
                 res_scale=1):
        super().__init__()
        first_conv = conv_block(in_nc, mid_nc, kernel_size, stride, dilation, groups, bias,
                                pad_type, norm_type, act_type, mode)

        # Settings for the second conv depend on the ordering mode.
        second_norm, second_act = norm_type, act_type
        if mode == 'CNA':
            second_act = None
        elif mode == 'CNAC':  # Residual path: |-CNAC-|
            second_act = None
            second_norm = None
        second_conv = conv_block(mid_nc, out_nc, kernel_size, stride, dilation, groups, bias,
                                 pad_type, second_norm, second_act, mode)

        self.res = sequential(first_conv, second_conv)
        self.res_scale = res_scale

    def forward(self, x):
        """Return ``x`` plus the residual path output scaled by ``res_scale``."""
        scaled_residual = self.res(x).mul(self.res_scale)
        return x + scaled_residual
|
||||
|
||||
|
||||
class ResidualDenseBlock_5C(nn.Module):
    """Residual dense block built from five convolutions.

    Core module of: Residual Dense Network for Image Super-Resolution (CVPR 18).

    Convs 1-4 each emit ``gc`` (growth) channels and take the block input
    concatenated with every earlier conv output along dim 1. Conv 5 maps the
    full concatenation back to ``nc`` channels with a fixed kernel size of 3;
    its activation is dropped when ``mode == 'CNA'``. The result is scaled by
    0.2 and added to the input.
    """

    def __init__(self, nc, kernel_size=3, gc=32, stride=1, bias=True, pad_type='zero',
                 norm_type=None, act_type='leakyrelu', mode='CNA'):
        super().__init__()
        # Keyword settings shared by all five conv blocks.
        shared = dict(bias=bias, pad_type=pad_type, norm_type=norm_type, mode=mode)

        # gc: growth channel, i.e. intermediate channels
        self.conv1 = conv_block(nc, gc, kernel_size, stride, act_type=act_type, **shared)
        self.conv2 = conv_block(nc + gc, gc, kernel_size, stride, act_type=act_type, **shared)
        self.conv3 = conv_block(nc + 2 * gc, gc, kernel_size, stride, act_type=act_type, **shared)
        self.conv4 = conv_block(nc + 3 * gc, gc, kernel_size, stride, act_type=act_type, **shared)

        last_act = None if mode == 'CNA' else act_type
        self.conv5 = conv_block(nc + 4 * gc, nc, 3, stride, act_type=last_act, **shared)

    def forward(self, x):
        """Apply the five densely connected convs and add ``0.2 * result`` to ``x``."""
        # The first conv sees the raw input; every later conv sees the input
        # plus all previously produced feature maps.
        collected = [x, self.conv1(x)]
        for dense_conv in (self.conv2, self.conv3, self.conv4):
            collected.append(dense_conv(torch.cat(tuple(collected), 1)))
        fused = self.conv5(torch.cat(tuple(collected), 1))
        return fused.mul(0.2) + x
|
||||
|
||||
|
||||
class RRDB(nn.Module):
    """Residual in Residual Dense Block.

    From ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks.

    Chains three identically configured ``ResidualDenseBlock_5C`` modules
    (``RDB1``, ``RDB2``, ``RDB3``) and adds their output, scaled by 0.2, back
    onto the block input.
    """

    def __init__(self, nc, kernel_size=3, gc=32, stride=1, bias=True, pad_type='zero',
                 norm_type=None, act_type='leakyrelu', mode='CNA'):
        super().__init__()
        rdb_args = (nc, kernel_size, gc, stride, bias, pad_type, norm_type, act_type, mode)
        # Registered in order so parameter names and initialisation order stay
        # RDB1 -> RDB2 -> RDB3.
        for attr_name in ('RDB1', 'RDB2', 'RDB3'):
            setattr(self, attr_name, ResidualDenseBlock_5C(*rdb_args))

    def forward(self, x):
        """Return ``x + 0.2 * RDB3(RDB2(RDB1(x)))``."""
        out = x
        for dense_block in (self.RDB1, self.RDB2, self.RDB3):
            out = dense_block(out)
        return out.mul(0.2) + x
|
||||
|
||||
|
||||
####################
|
||||
# Upsampler
|
||||
####################
|
|
@ -78,6 +78,65 @@ class Discriminator_VGG_128(nn.Module):
|
|||
return out
|
||||
|
||||
|
||||
class Discriminator_VGG_128_GN(nn.Module):
    """VGG-style discriminator that normalises with 8-group ``GroupNorm``.

    Layout: one biased 3x3 conv, then five stride-2 4x4 convs (each halving the
    spatial size) interleaved with 3x3 convs that widen the channels
    nf -> 2nf -> 4nf -> 8nf -> 8nf. Every conv except the very first is
    followed by GroupNorm; LeakyReLU(0.2) follows every conv/norm pair. The
    flattened features go through ``Linear(., 100)`` and ``Linear(100, 1)``.

    The norm layers keep the ``bn*`` attribute names even though they are
    GroupNorm instances.

    Args:
        in_nc: number of input image channels.
        nf: base channel width; GroupNorm uses 8 groups over nf, 2nf, 4nf, 8nf.
        input_img_factor: multiplier to support images over 128x128. Only
            certain factors are supported. The first linear layer is sized for
            a ``(4 * input_img_factor)``-square final feature map.
    """

    def __init__(self, in_nc, nf, input_img_factor=1):
        super().__init__()

        def widen(c_in, c_out):
            # 3x3, stride 1: keeps resolution, changes channel count.
            return nn.Conv2d(c_in, c_out, 3, 1, 1, bias=False)

        def halve(channels):
            # 4x4, stride 2: halves resolution, keeps channel count.
            return nn.Conv2d(channels, channels, 4, 2, 1, bias=False)

        def group_norm(channels):
            return nn.GroupNorm(8, channels, affine=True)

        # Stage 0: the only biased conv and the only conv without a norm.
        self.conv0_0 = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
        self.conv0_1 = halve(nf)
        self.bn0_1 = group_norm(nf)

        # Stages 1-4: (stage index, input channels, output channels).
        stage_plan = (
            (1, nf, nf * 2),
            (2, nf * 2, nf * 4),
            (3, nf * 4, nf * 8),
            (4, nf * 8, nf * 8),
        )
        for stage, c_in, c_out in stage_plan:
            setattr(self, 'conv{}_0'.format(stage), widen(c_in, c_out))
            setattr(self, 'bn{}_0'.format(stage), group_norm(c_out))
            setattr(self, 'conv{}_1'.format(stage), halve(c_out))
            setattr(self, 'bn{}_1'.format(stage), group_norm(c_out))

        final_nf = nf * 8
        self.linear1 = nn.Linear(final_nf * 4 * input_img_factor * 4 * input_img_factor, 100)
        self.linear2 = nn.Linear(100, 1)

        # activation function
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        """Return one raw score per input image (no final activation is applied)."""
        fea = self.lrelu(self.conv0_0(x))
        fea = self.lrelu(self.bn0_1(self.conv0_1(fea)))

        # Stages 1-4 each apply conv -> norm -> lrelu twice (widen, then halve).
        for stage in range(1, 5):
            for part in (0, 1):
                conv = getattr(self, 'conv{}_{}'.format(stage, part))
                norm = getattr(self, 'bn{}_{}'.format(stage, part))
                fea = self.lrelu(norm(conv(fea)))

        flat = fea.contiguous().view(fea.size(0), -1)
        hidden = self.lrelu(self.linear1(flat))
        return self.linear2(hidden)
|
||||
|
||||
class Discriminator_VGG_PixLoss(nn.Module):
|
||||
def __init__(self, in_nc, nf):
|
||||
super(Discriminator_VGG_PixLoss, self).__init__()
|
||||
|
|
|
@ -11,6 +11,8 @@ import models.archs.feature_arch as feature_arch
|
|||
import models.archs.SwitchedResidualGenerator_arch as SwitchedGen_arch
|
||||
import models.archs.SRG1_arch as srg1
|
||||
import models.archs.ProgressiveSrg_arch as psrg
|
||||
import models.archs.SPSR_arch as spsr
|
||||
import models.archs.arch_util as arch_util
|
||||
import functools
|
||||
from collections import OrderedDict
|
||||
|
||||
|
@ -97,6 +99,12 @@ def define_G(opt, net_key='network_G'):
|
|||
initial_temp=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step'],
|
||||
upsample_factor=scale, add_scalable_noise_to_transforms=opt_net['add_noise'],
|
||||
start_step=opt_net['start_step'])
|
||||
elif which_model == 'spsr_net':
|
||||
netG = spsr.SPSRNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'],
|
||||
nb=opt_net['nb'], gc=opt_net['gc'], upscale=opt_net['scale'], norm_type=opt_net['norm_type'],
|
||||
act_type='leakyrelu', mode=opt_net['mode'], upsample_mode='upconv')
|
||||
if opt['is_train']:
|
||||
arch_util.initialize_weights(netG, scale=.1)
|
||||
|
||||
# image corruption
|
||||
elif which_model == 'HighToLowResNet':
|
||||
|
@ -119,6 +127,8 @@ def define_D_net(opt_net, img_sz=None):
|
|||
|
||||
if which_model == 'discriminator_vgg_128':
|
||||
netD = SRGAN_arch.Discriminator_VGG_128(in_nc=opt_net['in_nc'], nf=opt_net['nf'], input_img_factor=img_sz // 128, extra_conv=opt_net['extra_conv'])
|
||||
elif which_model == 'discriminator_vgg_128_gn':
|
||||
netD = SRGAN_arch.Discriminator_VGG_128_GN(in_nc=opt_net['in_nc'], nf=opt_net['nf'], input_img_factor=img_sz // 128)
|
||||
elif which_model == 'discriminator_resnet':
|
||||
netD = DiscriminatorResnet_arch.fixup_resnet34(num_filters=opt_net['nf'], num_classes=1, input_img_size=img_sz)
|
||||
elif which_model == 'discriminator_resnet_passthrough':
|
||||
|
|
|
@ -115,7 +115,11 @@ def check_resume(opt, resume_iter):
|
|||
opt['path']['pretrain_model_G'] = osp.join(opt['path']['models'],
|
||||
'{}_G.pth'.format(resume_iter))
|
||||
logger.info('Set [pretrain_model_G] to ' + opt['path']['pretrain_model_G'])
|
||||
if 'gan' in opt['model']:
|
||||
if 'gan' in opt['model'] or 'spsr' in opt['model']:
|
||||
opt['path']['pretrain_model_D'] = osp.join(opt['path']['models'],
|
||||
'{}_D.pth'.format(resume_iter))
|
||||
logger.info('Set [pretrain_model_D] to ' + opt['path']['pretrain_model_D'])
|
||||
if 'spsr' in opt['model']:
|
||||
opt['path']['pretrain_model_D_grad'] = osp.join(opt['path']['models'],
|
||||
'{}_D_grad.pth'.format(resume_iter))
|
||||
logger.info('Set [pretrain_model_D_grad] to ' + opt['path']['pretrain_model_D_grad'])
|
||||
|
|
|
@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs):
|
|||
def main():
|
||||
#### options
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_spsr.yml')
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_spsr_rrdb.yml')
|
||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
|
||||
help='job launcher')
|
||||
parser.add_argument('--local_rank', type=int, default=0)
|
||||
|
@ -215,7 +215,7 @@ def main():
|
|||
logger.info(message)
|
||||
#### validation
|
||||
if opt['datasets'].get('val', None) and current_step % opt['train']['val_freq'] == 0:
|
||||
if opt['model'] in ['sr', 'srgan', 'corruptgan', 'spsr'] and rank <= 0: # image restoration validation
|
||||
if opt['model'] in ['sr', 'srgan', 'corruptgan', 'spsrgan'] and rank <= 0: # image restoration validation
|
||||
model.force_restore_swapout()
|
||||
val_batch_sz = 1 if 'batch_size' not in opt['datasets']['val'].keys() else opt['datasets']['val']['batch_size']
|
||||
# does not support multi-GPU validation
|
||||
|
|
Loading…
Reference in New Issue
Block a user