Add SPSR_module
This is a port from the SPSR repo. It is going to need a lot of work to be properly integrated, but as of this commit it at least runs.
This commit is contained in:
parent
f33ed578a2
commit
f894ba8f98
462
codes/models/SPSR_model.py
Normal file
462
codes/models/SPSR_model.py
Normal file
|
@ -0,0 +1,462 @@
|
|||
import os
|
||||
import logging
|
||||
from collections import OrderedDict
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.optim import lr_scheduler
|
||||
|
||||
import models.SPSR_networks as networks
|
||||
from .base_model import BaseModel
|
||||
from models.SPSR_modules.loss import GANLoss, GradientPenaltyLoss
|
||||
logger = logging.getLogger('base')
|
||||
|
||||
import torch.nn.functional as F
|
||||
|
||||
class Get_gradient(nn.Module):
    """Per-channel gradient-magnitude map of an image batch (padding=2).

    Each of the first three channels of the input is filtered with a
    vertical and a horizontal central-difference 3x3 kernel and the two
    responses are combined as ``sqrt(v**2 + h**2 + 1e-6)``.

    Because the 3x3 kernels are applied with ``padding=2``, the output is
    two pixels larger than the input along both spatial dimensions.

    The kernels are fixed (``requires_grad=False``).  They are moved to the
    input's device and dtype on every call, so the module works on CPU and
    GPU alike without the caller having to relocate it; calling
    ``.to(device)`` on the module beforehand simply makes that move a no-op.
    """

    def __init__(self):
        super(Get_gradient, self).__init__()
        kernel_v = [[0, -1, 0],
                    [0, 0, 0],
                    [0, 1, 0]]
        kernel_h = [[0, 0, 0],
                    [-1, 0, 1],
                    [0, 0, 0]]
        # Shape (1, 1, 3, 3): one output channel, one input channel.
        kernel_h = torch.FloatTensor(kernel_h).unsqueeze(0).unsqueeze(0)
        kernel_v = torch.FloatTensor(kernel_v).unsqueeze(0).unsqueeze(0)
        # No unconditional .cuda() here: that made construction fail on
        # machines without CUDA and silently replaced the Parameter by a
        # plain tensor that nn.Module.to() could no longer move.
        self.weight_h = nn.Parameter(data=kernel_h, requires_grad=False)
        self.weight_v = nn.Parameter(data=kernel_v, requires_grad=False)

    def forward(self, x):
        """Return the gradient magnitude of channels 0..2 of ``x``.

        Args:
            x: 4-D tensor ``(batch, channels, H, W)`` with at least three
                channels; channels beyond the third are ignored.

        Returns:
            Tensor of shape ``(batch, 3, H + 2, W + 2)``.

        Raises:
            IndexError: if ``x`` has fewer than three channels.
        """
        weight_v = self.weight_v.to(device=x.device, dtype=x.dtype)
        weight_h = self.weight_h.to(device=x.device, dtype=x.dtype)

        magnitudes = []
        for channel in range(3):
            plane = x[:, channel].unsqueeze(1)
            grad_v = F.conv2d(plane, weight_v, padding=2)
            grad_h = F.conv2d(plane, weight_h, padding=2)
            # The epsilon keeps sqrt differentiable where both responses are 0.
            magnitudes.append(torch.sqrt(torch.pow(grad_v, 2) + torch.pow(grad_h, 2) + 1e-6))

        return torch.cat(magnitudes, dim=1)
|
||||
|
||||
class Get_gradient_nopadding(nn.Module):
    """Per-channel gradient-magnitude map that keeps the input's spatial size.

    Each of the first three channels of the input is filtered with a
    vertical and a horizontal central-difference 3x3 kernel and the two
    responses are combined as ``sqrt(v**2 + h**2 + 1e-6)``.

    Despite the class name, the convolutions use ``padding=1``; with 3x3
    kernels that leaves height and width unchanged.

    The kernels are fixed (``requires_grad=False``).  They are moved to the
    input's device and dtype on every call, so the module works on CPU and
    GPU alike without the caller having to relocate it.
    """

    def __init__(self):
        super(Get_gradient_nopadding, self).__init__()
        kernel_v = [[0, -1, 0],
                    [0, 0, 0],
                    [0, 1, 0]]
        kernel_h = [[0, 0, 0],
                    [-1, 0, 1],
                    [0, 0, 0]]
        # Shape (1, 1, 3, 3): one output channel, one input channel.
        kernel_h = torch.FloatTensor(kernel_h).unsqueeze(0).unsqueeze(0)
        kernel_v = torch.FloatTensor(kernel_v).unsqueeze(0).unsqueeze(0)
        # No unconditional .cuda() here: that made construction fail on
        # machines without CUDA and silently replaced the Parameter by a
        # plain tensor that nn.Module.to() could no longer move.
        self.weight_h = nn.Parameter(data=kernel_h, requires_grad=False)
        self.weight_v = nn.Parameter(data=kernel_v, requires_grad=False)

    def forward(self, x):
        """Return the gradient magnitude of channels 0..2 of ``x``.

        Args:
            x: 4-D tensor ``(batch, channels, H, W)`` with at least three
                channels; channels beyond the third are ignored.

        Returns:
            Tensor of shape ``(batch, 3, H, W)``.

        Raises:
            IndexError: if ``x`` has fewer than three channels.
        """
        weight_v = self.weight_v.to(device=x.device, dtype=x.dtype)
        weight_h = self.weight_h.to(device=x.device, dtype=x.dtype)

        magnitudes = []
        for channel in range(3):
            plane = x[:, channel].unsqueeze(1)
            grad_v = F.conv2d(plane, weight_v, padding=1)
            grad_h = F.conv2d(plane, weight_h, padding=1)
            # The epsilon keeps sqrt differentiable where both responses are 0.
            magnitudes.append(torch.sqrt(torch.pow(grad_v, 2) + torch.pow(grad_h, 2) + 1e-6))

        return torch.cat(magnitudes, dim=1)
|
||||
|
||||
|
||||
class SPSRModel(BaseModel):
    """Training/inference wrapper around the SPSR generator and its two discriminators.

    The generator ``netG`` returns three tensors per forward pass (a gradient
    branch output, the super-resolved image and the gradient map of the
    low-resolution input).  During training two discriminators are used:
    ``netD`` on images and ``netD_grad`` on gradient maps produced by
    ``Get_gradient``.

    All hyper-parameters are read from ``opt`` / ``opt['train']``; the keys
    are used exactly as spelled in the code below (including ``'gp_weigth'``).
    """

    def __init__(self, opt):
        """Build networks, losses, optimizers and schedulers from ``opt``.

        Args:
            opt: option mapping.  ``opt['train']`` is only indexed when
                ``self.is_train`` is true.

        Raises:
            NotImplementedError: for an unknown loss criterion name or a
                learning-rate scheme other than ``'MultiStepLR'``.
        """
        super(SPSRModel, self).__init__(opt)
        train_opt = opt['train']

        # define networks and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)  # G
        if self.is_train:
            self.netD = networks.define_D(opt).to(self.device)  # D
            self.netD_grad = networks.define_D_grad(opt).to(self.device)  # D_grad
            self.netG.train()
            self.netD.train()
            self.netD_grad.train()
        self.load()  # load G and D if needed

        # define losses, optimizer and scheduler
        if self.is_train:
            # G pixel loss
            if train_opt['pixel_weight'] > 0:
                self.cri_pix = self._build_criterion(train_opt['pixel_criterion'])
                self.l_pix_w = train_opt['pixel_weight']
            else:
                logger.info('Remove pixel loss.')
                self.cri_pix = None

            # G feature loss
            if train_opt['feature_weight'] > 0:
                self.cri_fea = self._build_criterion(train_opt['feature_criterion'])
                self.l_fea_w = train_opt['feature_weight']
            else:
                logger.info('Remove feature loss.')
                self.cri_fea = None
            if self.cri_fea:  # load VGG perceptual loss
                self.netF = networks.define_F(opt, use_bn=False).to(self.device)

            # GD gan loss
            self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device)
            self.l_gan_w = train_opt['gan_weight']
            # D_update_ratio and D_init_iters are for WGAN
            self.D_update_ratio = train_opt['D_update_ratio'] or 1
            self.D_init_iters = train_opt['D_init_iters'] or 0
            # Branch_init_iters
            self.Branch_pretrain = train_opt['Branch_pretrain'] or 0
            self.Branch_init_iters = train_opt['Branch_init_iters'] or 1

            if train_opt['gan_type'] == 'wgan-gp':
                self.random_pt = torch.Tensor(1, 1, 1, 1).to(self.device)
                # gradient penalty loss
                self.cri_gp = GradientPenaltyLoss(device=self.device).to(self.device)
                # NOTE: 'gp_weigth' is the option key as it is read here;
                # renaming it would require changing the option files too.
                self.l_gp_w = train_opt['gp_weigth']

            # gradient_pixel_loss
            if train_opt['gradient_pixel_weight'] > 0:
                self.cri_pix_grad = nn.MSELoss().to(self.device)
                self.l_pix_grad_w = train_opt['gradient_pixel_weight']
            else:
                self.cri_pix_grad = None

            # gradient_gan_loss
            if train_opt['gradient_gan_weight'] > 0:
                self.cri_grad_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device)
                self.l_gan_grad_w = train_opt['gradient_gan_weight']
            else:
                self.cri_grad_gan = None

            # G_grad pixel loss
            if train_opt['pixel_branch_weight'] > 0:
                self.cri_pix_branch = self._build_criterion(train_opt['pixel_branch_criterion'])
                self.l_pix_branch_w = train_opt['pixel_branch_weight']
            else:
                logger.info('Remove G_grad pixel loss.')
                self.cri_pix_branch = None

            # optimizers
            # G
            wd_G = train_opt['weight_decay_G'] or 0

            optim_params = []
            for k, v in self.netG.named_parameters():  # optimize part of the model
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    logger.warning('Params [{:s}] will not optimize.'.format(k))
            self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'],
                                                weight_decay=wd_G,
                                                betas=(train_opt['beta1_G'], 0.999))
            self.optimizers.append(self.optimizer_G)

            # D
            wd_D = train_opt['weight_decay_D'] or 0
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=train_opt['lr_D'],
                                                weight_decay=wd_D,
                                                betas=(train_opt['beta1_D'], 0.999))
            self.optimizers.append(self.optimizer_D)

            # D_grad (shares the D learning rate, weight decay and beta1 options)
            wd_D_grad = train_opt['weight_decay_D'] or 0
            self.optimizer_D_grad = torch.optim.Adam(self.netD_grad.parameters(),
                                                     lr=train_opt['lr_D'],
                                                     weight_decay=wd_D_grad,
                                                     betas=(train_opt['beta1_D'], 0.999))
            self.optimizers.append(self.optimizer_D_grad)

            # schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(lr_scheduler.MultiStepLR(
                        optimizer, train_opt['lr_steps'], train_opt['lr_gamma']))
            else:
                raise NotImplementedError('MultiStepLR learning rate scheme is enough.')

        self.log_dict = OrderedDict()
        self.get_grad = Get_gradient()
        self.get_grad_nopadding = Get_gradient_nopadding()

    def _build_criterion(self, loss_type):
        """Return an L1 (``'l1'``) or MSE (``'l2'``) loss module on ``self.device``."""
        if loss_type == 'l1':
            return nn.L1Loss().to(self.device)
        if loss_type == 'l2':
            return nn.MSELoss().to(self.device)
        raise NotImplementedError('Loss type [{:s}] not recognized.'.format(loss_type))

    @staticmethod
    def _net_struc_str(net):
        """Return the class name of ``net``, unwrapping ``nn.DataParallel`` for display."""
        if isinstance(net, nn.DataParallel):
            return '{} - {}'.format(net.__class__.__name__, net.module.__class__.__name__)
        return '{}'.format(net.__class__.__name__)

    def feed_data(self, data, need_HR=True):
        """Move one batch to ``self.device``.

        Args:
            data: mapping with key ``'LQ'`` and, when ``need_HR`` is true,
                ``'GT'`` and optionally ``'ref'``.
            need_HR: also load the ground truth and the discriminator
                reference (falls back to ``'GT'`` when ``'ref'`` is absent).
        """
        # LR
        self.var_L = data['LQ'].to(self.device)

        if need_HR:  # train or val
            self.var_H = data['GT'].to(self.device)
            input_ref = data['ref'] if 'ref' in data else data['GT']
            self.var_ref = input_ref.to(self.device)

    def optimize_parameters(self, step):
        """Run one training step: generator, then ``netD``, then ``netD_grad``.

        The generator is only updated when ``step`` is a multiple of
        ``D_update_ratio`` and greater than ``D_init_iters``.  The
        gradient-GAN terms and the ``netD_grad`` update are skipped when the
        gradient GAN loss is disabled (``cri_grad_gan`` is ``None``).

        Args:
            step: current iteration number.
        """
        use_grad_gan = self.cri_grad_gan is not None
        update_G = step % self.D_update_ratio == 0 and step > self.D_init_iters

        # G: freeze both discriminators while the generator is updated.
        for p in self.netD.parameters():
            p.requires_grad = False
        for p in self.netD_grad.parameters():
            p.requires_grad = False

        if self.Branch_pretrain:
            # While step < Branch_init_iters only parameters whose name
            # contains 'f_' stay trainable; afterwards the rest is unfrozen.
            train_non_fusion = not (step < self.Branch_init_iters)
            for k, v in self.netG.named_parameters():
                if 'f_' not in k:
                    v.requires_grad = train_non_fusion

        self.optimizer_G.zero_grad()

        self.fake_H_branch, self.fake_H, self.grad_LR = self.netG(self.var_L)

        self.fake_H_grad = self.get_grad(self.fake_H)
        self.var_H_grad = self.get_grad(self.var_H)
        self.var_ref_grad = self.get_grad(self.var_ref)
        self.var_H_grad_nopadding = self.get_grad_nopadding(self.var_H)

        if update_G:
            l_g_total = 0
            if self.cri_pix:  # pixel loss
                l_g_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.var_H)
                l_g_total += l_g_pix
            if self.cri_fea:  # feature loss
                real_fea = self.netF(self.var_H).detach()
                fake_fea = self.netF(self.fake_H)
                l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea)
                l_g_total += l_g_fea

            if self.cri_pix_grad:  # gradient pixel loss
                l_g_pix_grad = self.l_pix_grad_w * self.cri_pix_grad(self.fake_H_grad, self.var_H_grad)
                l_g_total += l_g_pix_grad

            if self.cri_pix_branch:  # branch pixel loss
                l_g_pix_grad_branch = self.l_pix_branch_w * self.cri_pix_branch(
                    self.fake_H_branch, self.var_H_grad_nopadding)
                l_g_total += l_g_pix_grad_branch

            # G gan + cls loss
            pred_g_fake = self.netD(self.fake_H)
            pred_d_real = self.netD(self.var_ref).detach()

            l_g_gan = self.l_gan_w * (self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False) +
                                      self.cri_gan(pred_g_fake - torch.mean(pred_d_real), True)) / 2
            l_g_total += l_g_gan

            # grad G gan + cls loss
            if use_grad_gan:
                pred_g_fake_grad = self.netD_grad(self.fake_H_grad)
                pred_d_real_grad = self.netD_grad(self.var_ref_grad).detach()

                l_g_gan_grad = self.l_gan_grad_w * (
                    self.cri_grad_gan(pred_d_real_grad - torch.mean(pred_g_fake_grad), False) +
                    self.cri_grad_gan(pred_g_fake_grad - torch.mean(pred_d_real_grad), True)) / 2
                l_g_total += l_g_gan_grad

            l_g_total.backward()
            self.optimizer_G.step()

        # D
        for p in self.netD.parameters():
            p.requires_grad = True

        self.optimizer_D.zero_grad()
        pred_d_real = self.netD(self.var_ref)
        pred_d_fake = self.netD(self.fake_H.detach())  # detach to avoid BP to G
        l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True)
        l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real), False)

        l_d_total = (l_d_real + l_d_fake) / 2

        if self.opt['train']['gan_type'] == 'wgan-gp':
            batch_size = self.var_ref.size(0)
            if self.random_pt.size(0) != batch_size:
                self.random_pt.resize_(batch_size, 1, 1, 1)
            self.random_pt.uniform_()  # Draw random interpolation points
            interp = self.random_pt * self.fake_H.detach() + (1 - self.random_pt) * self.var_ref
            interp.requires_grad = True
            # NOTE(review): this unpacks two values from netD while every other
            # call in this method treats the netD output as a single tensor;
            # confirm what the configured discriminator returns.
            interp_crit, _ = self.netD(interp)
            l_d_gp = self.l_gp_w * self.cri_gp(interp, interp_crit)
            l_d_total += l_d_gp

        l_d_total.backward()
        self.optimizer_D.step()

        # D_grad
        for p in self.netD_grad.parameters():
            p.requires_grad = True

        if use_grad_gan:
            self.optimizer_D_grad.zero_grad()

            pred_d_real_grad = self.netD_grad(self.var_ref_grad)
            pred_d_fake_grad = self.netD_grad(self.fake_H_grad.detach())  # detach to avoid BP to G

            l_d_real_grad = self.cri_grad_gan(pred_d_real_grad - torch.mean(pred_d_fake_grad), True)
            l_d_fake_grad = self.cri_grad_gan(pred_d_fake_grad - torch.mean(pred_d_real_grad), False)

            l_d_total_grad = (l_d_real_grad + l_d_fake_grad) / 2

            l_d_total_grad.backward()
            self.optimizer_D_grad.step()

        # Periodically dump the current HR / LR / generated batches as image grids.
        if step % 50 == 0:
            import torchvision.utils as utils
            sample_save_path = os.path.join(self.opt['path']['models'], "..", "temp")
            os.makedirs(os.path.join(sample_save_path, "hr"), exist_ok=True)
            os.makedirs(os.path.join(sample_save_path, "lr"), exist_ok=True)
            os.makedirs(os.path.join(sample_save_path, "gen"), exist_ok=True)
            utils.save_image(self.var_H.cpu(), os.path.join(sample_save_path, "hr", "%05i.png" % (step,)))
            utils.save_image(self.var_L.cpu(), os.path.join(sample_save_path, "lr", "%05i.png" % (step,)))
            utils.save_image(self.fake_H.cpu(), os.path.join(sample_save_path, "gen", "%05i.png" % (step,)))

        # set log
        if update_G:
            # G
            if self.cri_pix:
                self.log_dict['l_g_pix'] = l_g_pix.item()
            if self.cri_fea:
                self.log_dict['l_g_fea'] = l_g_fea.item()
            self.log_dict['l_g_gan'] = l_g_gan.item()

            if self.cri_pix_branch:  # branch pixel loss
                self.log_dict['l_g_pix_grad_branch'] = l_g_pix_grad_branch.item()

        # D
        self.log_dict['l_d_real'] = l_d_real.item()
        self.log_dict['l_d_fake'] = l_d_fake.item()

        if self.opt['train']['gan_type'] == 'wgan-gp':
            self.log_dict['l_d_gp'] = l_d_gp.item()
        # D outputs (stored as Python floats, like every other entry)
        self.log_dict['D_real'] = torch.mean(pred_d_real.detach()).item()
        self.log_dict['D_fake'] = torch.mean(pred_d_fake.detach()).item()

        if use_grad_gan:
            # D_grad
            self.log_dict['l_d_real_grad'] = l_d_real_grad.item()
            self.log_dict['l_d_fake_grad'] = l_d_fake_grad.item()
            # D_grad outputs
            self.log_dict['D_real_grad'] = torch.mean(pred_d_real_grad.detach()).item()
            self.log_dict['D_fake_grad'] = torch.mean(pred_d_fake_grad.detach()).item()

    def test(self):
        """Run the generator on ``self.var_L`` without gradients, then restore train mode."""
        self.netG.eval()
        with torch.no_grad():
            self.fake_H_branch, self.fake_H, self.grad_LR = self.netG(self.var_L)

        self.netG.train()

    def get_current_log(self, step):
        """Return the log dictionary filled by ``optimize_parameters`` (``step`` is unused)."""
        return self.log_dict

    def get_current_visuals(self, need_HR=True):
        """Return the first sample of each current tensor as float CPU tensors.

        Keys: ``'LR'``, ``'SR'``, ``'SR_branch'``, ``'LR_grad'`` and, when
        ``need_HR`` is true, ``'HR'``.
        """
        out_dict = OrderedDict()
        out_dict['LR'] = self.var_L.detach()[0].float().cpu()

        out_dict['SR'] = self.fake_H.detach()[0].float().cpu()
        out_dict['SR_branch'] = self.fake_H_branch.detach()[0].float().cpu()
        out_dict['LR_grad'] = self.grad_LR.detach()[0].float().cpu()
        if need_HR:
            out_dict['HR'] = self.var_H.detach()[0].float().cpu()
        return out_dict

    def print_network(self):
        """Log the structure and parameter count of G and, in training, of D and F."""
        # Generator
        s, n = self.get_network_description(self.netG)
        net_struc_str = self._net_struc_str(self.netG)

        logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
        logger.info(s)
        if self.is_train:
            # Discriminator
            s, n = self.get_network_description(self.netD)
            net_struc_str = self._net_struc_str(self.netD)

            logger.info('Network D structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
            logger.info(s)

            if self.cri_fea:  # F, Perceptual Network
                s, n = self.get_network_description(self.netF)
                net_struc_str = self._net_struc_str(self.netF)

                logger.info('Network F structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
                logger.info(s)

    def load(self):
        """Load pretrained weights for G and, when training, for D.

        NOTE(review): ``netD_grad`` is saved by ``save`` but no path for it is
        read here, so it always starts from its initial weights.
        """
        load_path_G = self.opt['path']['pretrain_model_G']
        if load_path_G is not None:
            logger.info('Loading pretrained model for G [{:s}] ...'.format(load_path_G))
            self.load_network(load_path_G, self.netG)
        load_path_D = self.opt['path']['pretrain_model_D']
        if self.opt['is_train'] and load_path_D is not None:
            logger.info('Loading pretrained model for D [{:s}] ...'.format(load_path_D))
            self.load_network(load_path_D, self.netD)

    def save(self, iter_step):
        """Save G, D and D_grad under the labels ``'G'``, ``'D'`` and ``'D_grad'``."""
        self.save_network(self.netG, 'G', iter_step)
        self.save_network(self.netD, 'D', iter_step)
        self.save_network(self.netD_grad, 'D_grad', iter_step)
|
0
codes/models/SPSR_modules/__init__.py
Normal file
0
codes/models/SPSR_modules/__init__.py
Normal file
654
codes/models/SPSR_modules/architecture.py
Normal file
654
codes/models/SPSR_modules/architecture.py
Normal file
|
@ -0,0 +1,654 @@
|
|||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torchvision
|
||||
from . import block as B
|
||||
from . import spectral_norm as SN
|
||||
|
||||
|
||||
|
||||
class Get_gradient_nopadding(nn.Module):
    """Gradient-magnitude map computed independently for every input channel.

    Each channel is filtered with a vertical and a horizontal
    central-difference 3x3 kernel (``padding=1``, so height and width are
    unchanged) and the responses are merged as ``sqrt(v**2 + h**2 + 1e-6)``.
    The two kernels are registered as frozen parameters ``weight_h`` and
    ``weight_v``.
    """

    def __init__(self):
        super(Get_gradient_nopadding, self).__init__()
        horizontal = torch.FloatTensor([[0, 0, 0],
                                        [-1, 0, 1],
                                        [0, 0, 0]])
        vertical = torch.FloatTensor([[0, -1, 0],
                                      [0, 0, 0],
                                      [0, 1, 0]])
        # conv2d weights are (out_channels, in_channels, kH, kW) -> (1, 1, 3, 3).
        self.weight_h = nn.Parameter(data=horizontal.unsqueeze(0).unsqueeze(0), requires_grad=False)
        self.weight_v = nn.Parameter(data=vertical.unsqueeze(0).unsqueeze(0), requires_grad=False)

    def forward(self, x):
        """Return a tensor shaped like ``x`` holding per-channel gradient magnitudes."""
        magnitudes = []
        for channel in x.unbind(dim=1):
            plane = channel.unsqueeze(1)
            dv = F.conv2d(plane, self.weight_v, padding=1)
            dh = F.conv2d(plane, self.weight_h, padding=1)
            # The epsilon keeps sqrt differentiable where both responses are 0.
            magnitudes.append(torch.sqrt(dv.pow(2) + dh.pow(2) + 1e-6))
        return torch.cat(magnitudes, dim=1)
|
||||
|
||||
|
||||
####################
|
||||
# Generator
|
||||
####################
|
||||
|
||||
class SPSRNet(nn.Module):
    """SPSR generator: an RRDB image trunk plus a gradient branch fused at the end.

    ``forward`` returns three tensors: the gradient-branch output, the fused
    image output and the gradient map of the input.

    ``forward`` indexes the first 20 trunk blocks in four groups of five, so
    ``nb`` must be at least 20.

    NOTE(review): the ``gc`` argument is accepted but every RRDB is built
    with a literal ``gc=32``.
    """

    def __init__(self, in_nc, out_nc, nf, nb, gc=32, upscale=4, norm_type=None, \
            act_type='leakyrelu', mode='CNA', upsample_mode='upconv'):
        super(SPSRNet, self).__init__()

        # Modules below are created in a fixed order; keep it, since both the
        # random initialisation and the state_dict layout depend on it.
        def rrdb(channels):
            return B.RRDB(channels, kernel_size=3, gc=32, stride=1, bias=True, pad_type='zero',
                          norm_type=norm_type, act_type=act_type, mode='CNA')

        def conv(in_ch, out_ch, kernel_size=3, act=None):
            return B.conv_block(in_ch, out_ch, kernel_size=kernel_size, norm_type=None, act_type=act)

        def normed_conv():
            return B.conv_block(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None, mode=mode)

        def make_upsampler(block_fn):
            # x3 uses a single block with factor 3, otherwise n_upscale blocks.
            if upscale == 3:
                return block_fn(nf, nf, 3, act_type=act_type)
            return [block_fn(nf, nf, act_type=act_type) for _ in range(n_upscale)]

        n_upscale = int(math.log(upscale, 2))
        if upscale == 3:
            n_upscale = 1

        # ---- image trunk ----
        fea_conv = conv(in_nc, nf)
        rb_blocks = [rrdb(nf) for _ in range(nb)]
        LR_conv = normed_conv()

        if upsample_mode == 'upconv':
            upsample_block = B.upconv_blcok
        elif upsample_mode == 'pixelshuffle':
            upsample_block = B.pixelshuffle_block
        else:
            raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode))
        upsampler = make_upsampler(upsample_block)

        self.HR_conv0_new = conv(nf, nf, act=act_type)
        self.HR_conv1_new = conv(nf, nf)

        self.model = B.sequential(fea_conv, B.ShortcutBlock(B.sequential(*rb_blocks, LR_conv)),
                                  *upsampler, self.HR_conv0_new)

        # ---- gradient branch ----
        self.get_g_nopadding = Get_gradient_nopadding()

        self.b_fea_conv = conv(in_nc, nf)

        self.b_concat_1 = conv(2 * nf, nf)
        self.b_block_1 = rrdb(nf * 2)

        self.b_concat_2 = conv(2 * nf, nf)
        self.b_block_2 = rrdb(nf * 2)

        self.b_concat_3 = conv(2 * nf, nf)
        self.b_block_3 = rrdb(nf * 2)

        self.b_concat_4 = conv(2 * nf, nf)
        self.b_block_4 = rrdb(nf * 2)

        self.b_LR_conv = normed_conv()

        b_upsampler = make_upsampler(upsample_block)

        b_HR_conv0 = conv(nf, nf, act=act_type)
        b_HR_conv1 = conv(nf, nf)

        self.b_module = B.sequential(*b_upsampler, b_HR_conv0, b_HR_conv1)

        self.conv_w = conv(nf, out_nc, kernel_size=1)

        # ---- fusion ----
        self.f_concat = conv(nf * 2, nf)

        self.f_block = rrdb(nf * 2)

        self.f_HR_conv0 = conv(nf, nf, act=act_type)
        self.f_HR_conv1 = conv(nf, out_nc)

    def forward(self, x):
        """Return ``(x_out_branch, x_out, x_grad)`` for the input batch ``x``."""
        x_grad = self.get_g_nopadding(x)
        x = self.model[0](x)

        # The shortcut block is unpacked into a tensor and an indexable /
        # sliceable container of its inner blocks, which are then applied
        # manually so intermediate features can be tapped.
        x, block_list = self.model[1](x)
        x_ori = x

        # Tap the trunk after blocks 5, 10, 15 and 20.
        tapped = []
        for start in (0, 5, 10, 15):
            for idx in range(start, start + 5):
                x = block_list[idx](x)
            tapped.append(x)

        x = block_list[20:](x)
        # short cut
        x = x_ori + x
        x = self.model[2:](x)
        x = self.HR_conv1_new(x)

        # Gradient branch: fuse the running branch feature with each trunk tap.
        x_b_fea = self.b_fea_conv(x_grad)
        stages = (
            (self.b_block_1, self.b_concat_1),
            (self.b_block_2, self.b_concat_2),
            (self.b_block_3, self.b_concat_3),
            (self.b_block_4, self.b_concat_4),
        )
        branch = x_b_fea
        for (block, squeeze), trunk_fea in zip(stages, tapped):
            branch = squeeze(block(torch.cat([branch, trunk_fea], dim=1)))

        branch = self.b_LR_conv(branch)
        # short cut
        branch = branch + x_b_fea
        x_branch = self.b_module(branch)

        x_out_branch = self.conv_w(x_branch)

        # Fuse the branch features with the trunk output.
        x_f_cat = self.f_block(torch.cat([x_branch, x], dim=1))
        x_out = self.f_concat(x_f_cat)
        x_out = self.f_HR_conv0(x_out)
        x_out = self.f_HR_conv1(x_out)

        return x_out_branch, x_out, x_grad
|
||||
|
||||
|
||||
####################
|
||||
# Discriminator
|
||||
####################
|
||||
|
||||
|
||||
# VGG style Discriminator with input size 128*128
|
||||
class Discriminator_VGG_128(nn.Module):
    """VGG-style discriminator whose classifier expects a 4x4 final feature map.

    The feature stack is one stride-1 conv followed by five stride-2 convs
    interleaved with stride-1 convs; channels grow from ``base_nf`` to
    ``base_nf * 8``.  The classifier input width is derived from ``base_nf``
    (it equals the former hard-coded ``512 * 4 * 4`` when ``base_nf == 64``).

    Args:
        in_nc: number of input channels.
        base_nf: number of channels of the first conv layer.
        norm_type, act_type, mode: forwarded to ``B.conv_block``.
    """

    def __init__(self, in_nc, base_nf, norm_type='batch', act_type='leakyrelu', mode='CNA'):
        super(Discriminator_VGG_128, self).__init__()

        def block(in_ch, out_ch, kernel_size, stride):
            return B.conv_block(in_ch, out_ch, kernel_size=kernel_size, stride=stride,
                                norm_type=norm_type, act_type=act_type, mode=mode)

        # features: first conv has no normalisation, then a stride-2 conv.
        layers = [
            B.conv_block(in_nc, base_nf, kernel_size=3, norm_type=None, act_type=act_type,
                         mode=mode),
            block(base_nf, base_nf, 4, 2),
        ]
        # Each stage: 3x3 stride-1 conv changing the width, then 4x4 stride-2 conv.
        for in_mult, out_mult in ((1, 2), (2, 4), (4, 8), (8, 8)):
            layers.append(block(base_nf * in_mult, base_nf * out_mult, 3, 1))
            layers.append(block(base_nf * out_mult, base_nf * out_mult, 4, 2))
        self.features = B.sequential(*layers)

        # classifier: the last feature map has base_nf * 8 channels and is
        # assumed to be 4x4 (128x128 input per the original annotations).
        self.classifier = nn.Sequential(
            nn.Linear(base_nf * 8 * 4 * 4, 100), nn.LeakyReLU(0.2, True), nn.Linear(100, 1))

    def forward(self, x):
        """Return one score per sample, shape ``(batch, 1)``."""
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x
|
||||
|
||||
|
||||
# VGG style Discriminator with input size 96*96
|
||||
class Discriminator_VGG_96(nn.Module):
    """VGG-style discriminator whose classifier expects a 6x6 final feature map.

    The feature stack is one stride-1 conv followed by four stride-2 convs
    interleaved with stride-1 convs; channels grow from ``base_nf`` to
    ``base_nf * 8``.  The classifier input width is derived from ``base_nf``
    (it equals the former hard-coded ``512 * 6 * 6`` when ``base_nf == 64``).

    Args:
        in_nc: number of input channels.
        base_nf: number of channels of the first conv layer.
        norm_type, act_type, mode: forwarded to ``B.conv_block``.
    """

    def __init__(self, in_nc, base_nf, norm_type='batch', act_type='leakyrelu', mode='CNA'):
        super(Discriminator_VGG_96, self).__init__()

        def block(in_ch, out_ch, kernel_size, stride):
            return B.conv_block(in_ch, out_ch, kernel_size=kernel_size, stride=stride,
                                norm_type=norm_type, act_type=act_type, mode=mode)

        # features: first conv has no normalisation, then a stride-2 conv.
        layers = [
            B.conv_block(in_nc, base_nf, kernel_size=3, norm_type=None, act_type=act_type,
                         mode=mode),
            block(base_nf, base_nf, 4, 2),
        ]
        # Each stage: 3x3 stride-1 conv changing the width, then 4x4 stride-2 conv.
        for in_mult, out_mult in ((1, 2), (2, 4), (4, 8)):
            layers.append(block(base_nf * in_mult, base_nf * out_mult, 3, 1))
            layers.append(block(base_nf * out_mult, base_nf * out_mult, 4, 2))
        self.features = B.sequential(*layers)

        # classifier: the last feature map has base_nf * 8 channels and is
        # assumed to be 6x6 (96x96 input per the original annotations).
        self.classifier = nn.Sequential(
            nn.Linear(base_nf * 8 * 6 * 6, 100), nn.LeakyReLU(0.2, True), nn.Linear(100, 1))

    def forward(self, x):
        """Return one score per sample, shape ``(batch, 1)``."""
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x
|
||||
|
||||
|
||||
# VGG style Discriminator with input size 64*64
|
||||
class Discriminator_VGG_64(nn.Module):
    """VGG-style discriminator whose classifier expects a 4x4 final feature map.

    The feature stack is one stride-1 conv followed by four stride-2 convs
    interleaved with stride-1 convs; channels grow from ``base_nf`` to
    ``base_nf * 8``.  The classifier input width is derived from ``base_nf``
    (it equals the former hard-coded ``512 * 4 * 4`` when ``base_nf == 64``).

    Args:
        in_nc: number of input channels.
        base_nf: number of channels of the first conv layer.
        norm_type, act_type, mode: forwarded to ``B.conv_block``.
    """

    def __init__(self, in_nc, base_nf, norm_type='batch', act_type='leakyrelu', mode='CNA'):
        super(Discriminator_VGG_64, self).__init__()

        def block(in_ch, out_ch, kernel_size, stride):
            return B.conv_block(in_ch, out_ch, kernel_size=kernel_size, stride=stride,
                                norm_type=norm_type, act_type=act_type, mode=mode)

        # features: first conv has no normalisation, then a stride-2 conv.
        layers = [
            B.conv_block(in_nc, base_nf, kernel_size=3, norm_type=None, act_type=act_type,
                         mode=mode),
            block(base_nf, base_nf, 4, 2),
        ]
        # Each stage: 3x3 stride-1 conv changing the width, then 4x4 stride-2 conv.
        for in_mult, out_mult in ((1, 2), (2, 4), (4, 8)):
            layers.append(block(base_nf * in_mult, base_nf * out_mult, 3, 1))
            layers.append(block(base_nf * out_mult, base_nf * out_mult, 4, 2))
        self.features = B.sequential(*layers)

        # classifier: the last feature map has base_nf * 8 channels and is
        # assumed to be 4x4 (64x64 input per the original annotations).
        self.classifier = nn.Sequential(
            nn.Linear(base_nf * 8 * 4 * 4, 100), nn.LeakyReLU(0.2, True), nn.Linear(100, 1))

    def forward(self, x):
        """Return one score per sample, shape ``(batch, 1)``."""
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x
|
||||
|
||||
|
||||
# VGG style Discriminator with input size 32*32
|
||||
class Discriminator_VGG_32(nn.Module):
    """VGG-style discriminator whose classifier expects a 4x4 final feature map.

    The feature stack is one stride-1 conv followed by three stride-2 convs
    interleaved with stride-1 convs; channels grow from ``base_nf`` to
    ``base_nf * 4``.  The classifier input width is derived from ``base_nf``
    (it equals the former hard-coded ``256 * 4 * 4`` when ``base_nf == 64``).

    Args:
        in_nc: number of input channels.
        base_nf: number of channels of the first conv layer.
        norm_type, act_type, mode: forwarded to ``B.conv_block``.
    """

    def __init__(self, in_nc, base_nf, norm_type='batch', act_type='leakyrelu', mode='CNA'):
        super(Discriminator_VGG_32, self).__init__()

        def block(in_ch, out_ch, kernel_size, stride):
            return B.conv_block(in_ch, out_ch, kernel_size=kernel_size, stride=stride,
                                norm_type=norm_type, act_type=act_type, mode=mode)

        # features: first conv has no normalisation, then a stride-2 conv.
        layers = [
            B.conv_block(in_nc, base_nf, kernel_size=3, norm_type=None, act_type=act_type,
                         mode=mode),
            block(base_nf, base_nf, 4, 2),
        ]
        # Each stage: 3x3 stride-1 conv changing the width, then 4x4 stride-2 conv.
        for in_mult, out_mult in ((1, 2), (2, 4)):
            layers.append(block(base_nf * in_mult, base_nf * out_mult, 3, 1))
            layers.append(block(base_nf * out_mult, base_nf * out_mult, 4, 2))
        self.features = B.sequential(*layers)

        # classifier: the last feature map has base_nf * 4 channels and is
        # assumed to be 4x4 (32x32 input per the original annotations).
        self.classifier = nn.Sequential(
            nn.Linear(base_nf * 4 * 4 * 4, 100), nn.LeakyReLU(0.2, True), nn.Linear(100, 1))

    def forward(self, x):
        """Return one score per sample, shape ``(batch, 1)``."""
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x
|
||||
|
||||
|
||||
# VGG style Discriminator with input size 16*16
|
||||
class Discriminator_VGG_16(nn.Module):
    '''
    VGG style discriminator sized for 16x16 inputs.

    Two stages, each a 3x3 stride-1 conv block followed by a 4x4 stride-2
    conv block (16 -> 8 -> 4), with channel widths base_nf * (1, 2). A
    two-layer fully connected classifier maps the flattened features to one
    score per sample.

    in_nc: number of input channels
    base_nf: channel width of the first stage
    norm_type: normalization for every block except the very first one
    act_type: activation used by every conv block
    mode: layer ordering forwarded to B.conv_block
    '''

    def __init__(self, in_nc, base_nf, norm_type='batch', act_type='leakyrelu', mode='CNA'):
        super(Discriminator_VGG_16, self).__init__()
        # features
        blocks = []
        prev_nc = in_nc
        for stage, width in enumerate(base_nf * mult for mult in (1, 2)):
            blocks.append(B.conv_block(prev_nc, width, kernel_size=3, stride=1,
                                       norm_type=norm_type if stage else None,
                                       act_type=act_type, mode=mode))
            blocks.append(B.conv_block(width, width, kernel_size=4, stride=2,
                                       norm_type=norm_type, act_type=act_type, mode=mode))
            prev_nc = width
        self.features = B.sequential(*blocks)

        # classifier
        # The input width follows base_nf; it used to be hard-coded to 128,
        # which only matched base_nf == 64. Final feature map is 4x4.
        self.classifier = nn.Sequential(
            nn.Linear(base_nf * 2 * 4 * 4, 100), nn.LeakyReLU(0.2, True), nn.Linear(100, 1))

    def forward(self, x):
        '''Run the conv features, flatten per sample, and classify.'''
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x
|
||||
|
||||
|
||||
# VGG style Discriminator with input size 128*128, Spectral Normalization
|
||||
class Discriminator_VGG_128_SN(nn.Module):
    '''
    VGG style discriminator for 128x128 RGB inputs with spectral normalization.

    Ten spectrally normalized convolutions (alternating 3x3 stride-1 and
    4x4 stride-2) each followed by LeakyReLU(0.2), then two spectrally
    normalized linear layers producing one score per sample.
    '''

    def __init__(self):
        super(Discriminator_VGG_128_SN, self).__init__()
        # One shared in-place activation is reused after every layer.
        self.lrelu = nn.LeakyReLU(0.2, True)

        # features (hxw, c): 128, 64
        self.conv0 = SN.spectral_norm(nn.Conv2d(3, 64, 3, stride=1, padding=1))
        self.conv1 = SN.spectral_norm(nn.Conv2d(64, 64, 4, stride=2, padding=1))
        # 64, 64
        self.conv2 = SN.spectral_norm(nn.Conv2d(64, 128, 3, stride=1, padding=1))
        self.conv3 = SN.spectral_norm(nn.Conv2d(128, 128, 4, stride=2, padding=1))
        # 32, 128
        self.conv4 = SN.spectral_norm(nn.Conv2d(128, 256, 3, stride=1, padding=1))
        self.conv5 = SN.spectral_norm(nn.Conv2d(256, 256, 4, stride=2, padding=1))
        # 16, 256
        self.conv6 = SN.spectral_norm(nn.Conv2d(256, 512, 3, stride=1, padding=1))
        self.conv7 = SN.spectral_norm(nn.Conv2d(512, 512, 4, stride=2, padding=1))
        # 8, 512
        self.conv8 = SN.spectral_norm(nn.Conv2d(512, 512, 3, stride=1, padding=1))
        self.conv9 = SN.spectral_norm(nn.Conv2d(512, 512, 4, stride=2, padding=1))
        # 4, 512

        # classifier
        self.linear0 = SN.spectral_norm(nn.Linear(512 * 4 * 4, 100))
        self.linear1 = SN.spectral_norm(nn.Linear(100, 1))

    def forward(self, x):
        '''Apply conv0..conv9 with LeakyReLU, flatten, then the two linear layers.'''
        conv_stack = (self.conv0, self.conv1, self.conv2, self.conv3, self.conv4,
                      self.conv5, self.conv6, self.conv7, self.conv8, self.conv9)
        for conv in conv_stack:
            x = self.lrelu(conv(x))
        x = x.view(x.size(0), -1)
        x = self.lrelu(self.linear0(x))
        return self.linear1(x)
|
||||
|
||||
|
||||
class Discriminator_VGG_96(nn.Module):
    '''
    VGG style discriminator sized for 96x96 inputs.

    Five stages, each a 3x3 stride-1 conv block followed by a 4x4 stride-2
    conv block (96 -> 48 -> 24 -> 12 -> 6 -> 3), with channel widths
    base_nf * (1, 2, 4, 8, 8). A two-layer fully connected classifier maps the
    flattened features to one score per sample.

    in_nc: number of input channels
    base_nf: channel width of the first stage
    norm_type: normalization for every block except the very first one
    act_type: activation used by every conv block
    mode: layer ordering forwarded to B.conv_block
    '''

    def __init__(self, in_nc, base_nf, norm_type='batch', act_type='leakyrelu', mode='CNA'):
        super(Discriminator_VGG_96, self).__init__()
        # features
        blocks = []
        prev_nc = in_nc
        for stage, width in enumerate(base_nf * mult for mult in (1, 2, 4, 8, 8)):
            blocks.append(B.conv_block(prev_nc, width, kernel_size=3, stride=1,
                                       norm_type=norm_type if stage else None,
                                       act_type=act_type, mode=mode))
            blocks.append(B.conv_block(width, width, kernel_size=4, stride=2,
                                       norm_type=norm_type, act_type=act_type, mode=mode))
            prev_nc = width
        self.features = B.sequential(*blocks)

        # classifier
        # The input width follows base_nf; it used to be hard-coded to 512,
        # which only matched base_nf == 64. Final feature map is 3x3.
        self.classifier = nn.Sequential(
            nn.Linear(base_nf * 8 * 3 * 3, 100), nn.LeakyReLU(0.2, True), nn.Linear(100, 1))

    def forward(self, x):
        '''Run the conv features, flatten per sample, and classify.'''
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x
|
||||
|
||||
|
||||
class Discriminator_VGG_192(nn.Module):
    '''
    VGG style discriminator sized for 192x192 inputs.

    Six stages, each a 3x3 stride-1 conv block followed by a 4x4 stride-2
    conv block (192 -> 96 -> 48 -> 24 -> 12 -> 6 -> 3), with channel widths
    base_nf * (1, 2, 4, 8, 8, 8). A two-layer fully connected classifier maps
    the flattened features to one score per sample.

    in_nc: number of input channels
    base_nf: channel width of the first stage
    norm_type: normalization for every block except the very first one
    act_type: activation used by every conv block
    mode: layer ordering forwarded to B.conv_block
    '''

    def __init__(self, in_nc, base_nf, norm_type='batch', act_type='leakyrelu', mode='CNA'):
        super(Discriminator_VGG_192, self).__init__()
        # features
        blocks = []
        prev_nc = in_nc
        for stage, width in enumerate(base_nf * mult for mult in (1, 2, 4, 8, 8, 8)):
            blocks.append(B.conv_block(prev_nc, width, kernel_size=3, stride=1,
                                       norm_type=norm_type if stage else None,
                                       act_type=act_type, mode=mode))
            blocks.append(B.conv_block(width, width, kernel_size=4, stride=2,
                                       norm_type=norm_type, act_type=act_type, mode=mode))
            prev_nc = width
        self.features = B.sequential(*blocks)

        # classifier
        # The input width follows base_nf; it used to be hard-coded to 512,
        # which only matched base_nf == 64. Final feature map is 3x3.
        self.classifier = nn.Sequential(
            nn.Linear(base_nf * 8 * 3 * 3, 100), nn.LeakyReLU(0.2, True), nn.Linear(100, 1))

    def forward(self, x):
        '''Run the conv features, flatten per sample, and classify.'''
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x
|
||||
|
||||
|
||||
####################
|
||||
# Perceptual Network
|
||||
####################
|
||||
|
||||
class VGGFeatureExtractor(nn.Module):
    '''
    Frozen prefix of a pretrained torchvision VGG19 used as a feature extractor.

    feature_layer: index of the last layer of model.features that is kept
    use_bn: pick vgg19_bn instead of vgg19
    use_input_norm: subtract the stored mean and divide by the stored std
        before running the features
    device: device the mean/std tensors are created on
    '''

    def __init__(self,
                 feature_layer=34,
                 use_bn=False,
                 use_input_norm=True,
                 device=torch.device('cpu')):
        super(VGGFeatureExtractor, self).__init__()
        build_vgg = torchvision.models.vgg19_bn if use_bn else torchvision.models.vgg19
        backbone = build_vgg(pretrained=True)
        self.use_input_norm = use_input_norm
        if self.use_input_norm:
            # Per-channel statistics, shaped (1, 3, 1, 1) so they broadcast.
            # For inputs in range [-1, 1] use
            # mean [0.485-1, 0.456-1, 0.406-1] and std [0.229*2, 0.224*2, 0.225*2].
            channel_mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device)
            channel_std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device)
            self.register_buffer('mean', channel_mean)
            self.register_buffer('std', channel_std)
        kept_layers = list(backbone.features.children())[:(feature_layer + 1)]
        self.features = nn.Sequential(*kept_layers)
        # The extractor is fixed: no gradients for its parameters.
        for param in self.features.parameters():
            param.requires_grad = False

    def forward(self, x):
        '''Optionally normalize x, then return the truncated VGG features.'''
        if self.use_input_norm:
            x = (x - self.mean) / self.std
        return self.features(x)
|
||||
|
||||
|
||||
class ResNet101FeatureExtractor(nn.Module):
    '''
    Frozen prefix (first 8 children) of a pretrained torchvision ResNet-101.

    use_input_norm: subtract the stored mean and divide by the stored std
        before running the features
    device: device the mean/std tensors are created on
    '''

    def __init__(self, use_input_norm=True, device=torch.device('cpu')):
        super(ResNet101FeatureExtractor, self).__init__()
        backbone = torchvision.models.resnet101(pretrained=True)
        self.use_input_norm = use_input_norm
        if self.use_input_norm:
            # Per-channel statistics, shaped (1, 3, 1, 1) so they broadcast.
            # For inputs in range [-1, 1] use
            # mean [0.485-1, 0.456-1, 0.406-1] and std [0.229*2, 0.224*2, 0.225*2].
            channel_mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device)
            channel_std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device)
            self.register_buffer('mean', channel_mean)
            self.register_buffer('std', channel_std)
        kept_layers = list(backbone.children())[:8]
        self.features = nn.Sequential(*kept_layers)
        # The extractor is fixed: no gradients for its parameters.
        for param in self.features.parameters():
            param.requires_grad = False

    def forward(self, x):
        '''Optionally normalize x, then return the truncated ResNet features.'''
        if self.use_input_norm:
            x = (x - self.mean) / self.std
        return self.features(x)
|
||||
|
||||
|
||||
class MINCNet(nn.Module):
    '''
    Thirteen 3x3 convolutions in five groups (2, 2, 3, 3, 3 layers) with a
    ceil-mode 2x2 max pool after each of the first four groups. Every
    convolution except the last (conv53) is followed by an in-place ReLU.
    '''

    def __init__(self):
        super(MINCNet, self).__init__()
        self.ReLU = nn.ReLU(True)
        # group 1
        self.conv11 = nn.Conv2d(3, 64, 3, stride=1, padding=1)
        self.conv12 = nn.Conv2d(64, 64, 3, stride=1, padding=1)
        self.maxpool1 = nn.MaxPool2d(2, stride=2, padding=0, ceil_mode=True)
        # group 2
        self.conv21 = nn.Conv2d(64, 128, 3, stride=1, padding=1)
        self.conv22 = nn.Conv2d(128, 128, 3, stride=1, padding=1)
        self.maxpool2 = nn.MaxPool2d(2, stride=2, padding=0, ceil_mode=True)
        # group 3
        self.conv31 = nn.Conv2d(128, 256, 3, stride=1, padding=1)
        self.conv32 = nn.Conv2d(256, 256, 3, stride=1, padding=1)
        self.conv33 = nn.Conv2d(256, 256, 3, stride=1, padding=1)
        self.maxpool3 = nn.MaxPool2d(2, stride=2, padding=0, ceil_mode=True)
        # group 4
        self.conv41 = nn.Conv2d(256, 512, 3, stride=1, padding=1)
        self.conv42 = nn.Conv2d(512, 512, 3, stride=1, padding=1)
        self.conv43 = nn.Conv2d(512, 512, 3, stride=1, padding=1)
        self.maxpool4 = nn.MaxPool2d(2, stride=2, padding=0, ceil_mode=True)
        # group 5 (no pooling afterwards)
        self.conv51 = nn.Conv2d(512, 512, 3, stride=1, padding=1)
        self.conv52 = nn.Conv2d(512, 512, 3, stride=1, padding=1)
        self.conv53 = nn.Conv2d(512, 512, 3, stride=1, padding=1)

    def forward(self, x):
        '''Return the raw (un-activated) output of conv53.'''
        pooled_groups = (
            ((self.conv11, self.conv12), self.maxpool1),
            ((self.conv21, self.conv22), self.maxpool2),
            ((self.conv31, self.conv32, self.conv33), self.maxpool3),
            ((self.conv41, self.conv42, self.conv43), self.maxpool4),
        )
        out = x
        for convs, pool in pooled_groups:
            for conv in convs:
                out = self.ReLU(conv(out))
            out = pool(out)
        out = self.ReLU(self.conv51(out))
        out = self.ReLU(self.conv52(out))
        return self.conv53(out)
|
||||
|
||||
|
||||
class MINCFeatureExtractor(nn.Module):
    '''
    Frozen MINCNet loaded from a fixed checkpoint path.

    The constructor arguments are accepted but not used by this class; the
    weights are always read from
    '../experiments/pretrained_models/VGG16minc_53.pth' (relative to the
    current working directory).
    '''

    def __init__(self, feature_layer=34, use_bn=False, use_input_norm=True, \
                 device=torch.device('cpu')):
        super(MINCFeatureExtractor, self).__init__()

        self.features = MINCNet()
        pretrained_state = torch.load('../experiments/pretrained_models/VGG16minc_53.pth')
        self.features.load_state_dict(pretrained_state, strict=True)
        self.features.eval()
        # The extractor is fixed: no gradients for its parameters.
        for param in self.features.parameters():
            param.requires_grad = False

    def forward(self, x):
        '''Return the MINCNet features of x.'''
        return self.features(x)
|
258
codes/models/SPSR_modules/block.py
Normal file
258
codes/models/SPSR_modules/block.py
Normal file
|
@ -0,0 +1,258 @@
|
|||
from collections import OrderedDict
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
####################
|
||||
# Basic blocks
|
||||
####################
|
||||
|
||||
|
||||
def act(act_type, inplace=True, neg_slope=0.2, n_prelu=1):
    '''
    Build an activation layer by (case-insensitive) name.

    act_type: 'relu', 'leakyrelu' or 'prelu'
    inplace: forwarded to ReLU / LeakyReLU
    neg_slope: negative slope for leakyrelu, initial value for prelu
    n_prelu: num_parameters for prelu
    Raises NotImplementedError for any other name.
    '''
    kind = act_type.lower()
    if kind == 'relu':
        return nn.ReLU(inplace)
    if kind == 'leakyrelu':
        return nn.LeakyReLU(neg_slope, inplace)
    if kind == 'prelu':
        return nn.PReLU(num_parameters=n_prelu, init=neg_slope)
    raise NotImplementedError('activation layer [{:s}] is not found'.format(kind))
|
||||
|
||||
def norm(norm_type, nc):
    '''
    Build a 2D normalization layer by (case-insensitive) name.

    norm_type: 'batch' (affine BatchNorm2d) or 'instance' (non-affine InstanceNorm2d)
    nc: number of channels
    Raises NotImplementedError for any other name.
    '''
    kind = norm_type.lower()
    if kind == 'batch':
        return nn.BatchNorm2d(nc, affine=True)
    if kind == 'instance':
        return nn.InstanceNorm2d(nc, affine=False)
    raise NotImplementedError('normalization layer [{:s}] is not found'.format(kind))
|
||||
|
||||
def pad(pad_type, padding):
    '''
    Build an explicit padding layer by (case-insensitive) name.

    Returns None when padding is 0. 'reflect' and 'replicate' are supported;
    zero padding is expected to be done by the conv layer itself, so any
    other name raises NotImplementedError.
    '''
    kind = pad_type.lower()
    if padding == 0:
        return None
    if kind == 'reflect':
        return nn.ReflectionPad2d(padding)
    if kind == 'replicate':
        return nn.ReplicationPad2d(padding)
    raise NotImplementedError('padding layer [{:s}] is not implemented'.format(kind))
|
||||
|
||||
|
||||
def get_valid_padding(kernel_size, dilation):
    '''
    Return the per-side padding for a (possibly dilated) kernel:
    half of the effective kernel extent minus one, rounded down.
    '''
    # Effective extent of a dilated kernel: (k - 1) * d + 1.
    effective_size = (kernel_size - 1) * dilation + 1
    return (effective_size - 1) // 2
|
||||
|
||||
|
||||
class ConcatBlock(nn.Module):
    '''Concatenate a submodule's output to its input along dim 1.'''

    def __init__(self, submodule):
        super(ConcatBlock, self).__init__()
        self.sub = submodule

    def forward(self, x):
        '''Return cat((x, sub(x)), dim=1).'''
        return torch.cat((x, self.sub(x)), dim=1)

    def __repr__(self):
        # Prefix the submodule's repr and mark its continuation lines with '|'.
        sub_repr = self.sub.__repr__().replace('\n', '\n|')
        return 'Identity .. \n|' + sub_repr
|
||||
|
||||
|
||||
class ShortcutBlock(nn.Module):
    '''Elementwise sum the output of a submodule to its input.'''

    def __init__(self, submodule):
        super(ShortcutBlock, self).__init__()
        self.sub = submodule

    def forward(self, x):
        '''Return x + sub(x).'''
        # The previous body returned the tuple (x, self.sub) without ever
        # running the submodule, which breaks every residual connection
        # built with this block.
        return x + self.sub(x)

    def __repr__(self):
        # Prefix the submodule's repr and mark its continuation lines with '|'.
        tmpstr = 'Identity + \n|'
        modstr = self.sub.__repr__().replace('\n', '\n|')
        tmpstr = tmpstr + modstr
        return tmpstr
|
||||
|
||||
|
||||
def sequential(*args):
    '''
    Build a flat nn.Sequential from the given modules.

    Nested nn.Sequential arguments are unwrapped into their children and
    arguments that are not nn.Module instances (e.g. None) are dropped.
    A single argument is returned as-is, unless it is an OrderedDict, which
    raises NotImplementedError.
    '''
    if len(args) == 1:
        only = args[0]
        if isinstance(only, OrderedDict):
            raise NotImplementedError('sequential does not support OrderedDict input.')
        return only  # No sequential is needed.
    flat = []
    for item in args:
        if isinstance(item, nn.Sequential):
            flat.extend(item.children())
        elif isinstance(item, nn.Module):
            flat.append(item)
    return nn.Sequential(*flat)
|
||||
|
||||
|
||||
def conv_block(in_nc, out_nc, kernel_size, stride=1, dilation=1, groups=1, bias=True, \
               pad_type='zero', norm_type=None, act_type='relu', mode='CNA'):
    '''
    Conv layer with padding, normalization, activation
    mode: CNA (and CNAC) --> Conv -> Norm -> Act
        NAC --> Norm -> Act --> Conv (Identity Mappings in Deep Residual Networks, ECCV16)

    Zero padding is done by the conv itself; any other pad_type becomes an
    explicit padding layer placed right before the conv.
    '''
    assert mode in ['CNA', 'NAC', 'CNAC'], 'Wrong conv mode [{:s}]'.format(mode)
    full_padding = get_valid_padding(kernel_size, dilation)
    needs_pad_layer = pad_type and pad_type != 'zero'
    pad_layer = pad(pad_type, full_padding) if needs_pad_layer else None
    conv_padding = full_padding if pad_type == 'zero' else 0

    conv = nn.Conv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride,
                     padding=conv_padding, dilation=dilation, bias=bias, groups=groups)
    activation = act(act_type) if act_type else None
    if 'CNA' in mode:
        # Normalization follows the conv, so it sees out_nc channels.
        norm_layer = norm(norm_type, out_nc) if norm_type else None
        return sequential(pad_layer, conv, norm_layer, activation)
    if mode == 'NAC':
        if norm_type is None and act_type is not None:
            # Important!
            # input----ReLU(inplace)----Conv--+----output
            #        |________________________|
            # inplace ReLU will modify the input, therefore wrong output
            activation = act(act_type, inplace=False)
        # Normalization precedes the conv, so it sees in_nc channels.
        norm_layer = norm(norm_type, in_nc) if norm_type else None
        return sequential(norm_layer, activation, pad_layer, conv)
|
||||
|
||||
|
||||
####################
|
||||
# Useful blocks
|
||||
####################
|
||||
|
||||
class ResNetBlock(nn.Module):
    '''
    ResNet Block, 3-3 style
    with extra residual scaling used in EDSR
    (Enhanced Deep Residual Networks for Single Image Super-Resolution, CVPRW 17)

    The residual branch is two conv_blocks. In CNA mode the second block has
    no activation; in CNAC mode it has neither activation nor normalization.
    '''

    def __init__(self, in_nc, mid_nc, out_nc, kernel_size=3, stride=1, dilation=1, groups=1, \
                 bias=True, pad_type='zero', norm_type=None, act_type='relu', mode='CNA', res_scale=1):
        super(ResNetBlock, self).__init__()
        first = conv_block(in_nc, mid_nc, kernel_size, stride, dilation, groups, bias, pad_type, \
                           norm_type, act_type, mode)
        if mode in ('CNA', 'CNAC'):
            act_type = None
        if mode == 'CNAC':  # Residual path: |-CNAC-|
            norm_type = None
        second = conv_block(mid_nc, out_nc, kernel_size, stride, dilation, groups, bias, pad_type, \
                            norm_type, act_type, mode)
        # NOTE: there is no projection on the identity path; forward adds the
        # branch output directly to x.
        self.res = sequential(first, second)
        self.res_scale = res_scale

    def forward(self, x):
        '''Return x plus the residual branch output scaled by res_scale.'''
        return x + self.res(x).mul(self.res_scale)
|
||||
|
||||
|
||||
class ResidualDenseBlock_5C(nn.Module):
    '''
    Residual Dense Block
    style: 5 convs
    The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)

    Each conv sees the block input concatenated with all previous conv
    outputs; the fifth conv maps back to nc channels and its output is scaled
    by 0.2 before being added to the input.
    '''

    def __init__(self, nc, kernel_size=3, gc=32, stride=1, bias=True, pad_type='zero', \
                 norm_type=None, act_type='leakyrelu', mode='CNA'):
        super(ResidualDenseBlock_5C, self).__init__()
        # gc: growth channel, i.e. intermediate channels
        common = dict(bias=bias, pad_type=pad_type, norm_type=norm_type, mode=mode)
        self.conv1 = conv_block(nc, gc, kernel_size, stride, act_type=act_type, **common)
        self.conv2 = conv_block(nc + gc, gc, kernel_size, stride, act_type=act_type, **common)
        self.conv3 = conv_block(nc + 2 * gc, gc, kernel_size, stride, act_type=act_type, **common)
        self.conv4 = conv_block(nc + 3 * gc, gc, kernel_size, stride, act_type=act_type, **common)
        # In CNA mode the last conv has no activation.
        last_act = None if mode == 'CNA' else act_type
        # The last conv always uses a 3x3 kernel, independent of kernel_size.
        self.conv5 = conv_block(nc + 4 * gc, nc, 3, stride, act_type=last_act, **common)

    def forward(self, x):
        '''Return conv5(dense features) * 0.2 + x.'''
        dense = [x, self.conv1(x)]
        for conv in (self.conv2, self.conv3, self.conv4):
            dense.append(conv(torch.cat(dense, 1)))
        x5 = self.conv5(torch.cat(dense, 1))
        return x5.mul(0.2) + x
|
||||
|
||||
|
||||
class RRDB(nn.Module):
    '''
    Residual in Residual Dense Block
    (ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks)

    Three ResidualDenseBlock_5C in series; their combined output is scaled by
    0.2 and added to the block input.
    '''

    def __init__(self, nc, kernel_size=3, gc=32, stride=1, bias=True, pad_type='zero', \
                 norm_type=None, act_type='leakyrelu', mode='CNA'):
        super(RRDB, self).__init__()
        rdb_args = (nc, kernel_size, gc, stride, bias, pad_type, norm_type, act_type, mode)
        self.RDB1 = ResidualDenseBlock_5C(*rdb_args)
        self.RDB2 = ResidualDenseBlock_5C(*rdb_args)
        self.RDB3 = ResidualDenseBlock_5C(*rdb_args)

    def forward(self, x):
        '''Return RDB3(RDB2(RDB1(x))) * 0.2 + x.'''
        out = self.RDB3(self.RDB2(self.RDB1(x)))
        return out.mul(0.2) + x
|
||||
|
||||
|
||||
####################
|
||||
# Upsampler
|
||||
####################
|
||||
|
||||
|
||||
def pixelshuffle_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True, \
                       pad_type='zero', norm_type=None, act_type='relu'):
    '''
    Pixel shuffle layer
    (Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional
    Neural Network, CVPR17)

    A plain conv (no norm, no activation) expands to out_nc * upscale_factor^2
    channels, nn.PixelShuffle rearranges them, and the optional norm and
    activation are applied afterwards on out_nc channels.
    '''
    expand_conv = conv_block(in_nc, out_nc * (upscale_factor ** 2), kernel_size, stride,
                             bias=bias, pad_type=pad_type, norm_type=None, act_type=None)
    shuffle = nn.PixelShuffle(upscale_factor)

    post_norm = norm(norm_type, out_nc) if norm_type else None
    post_act = act(act_type) if act_type else None
    return sequential(expand_conv, shuffle, post_norm, post_act)
|
||||
|
||||
|
||||
def upconv_blcok(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True, \
                 pad_type='zero', norm_type=None, act_type='relu', mode='nearest'):
    '''
    Up conv: nn.Upsample followed by a conv_block
    described in https://distill.pub/2016/deconv-checkerboard/

    mode: interpolation mode passed to nn.Upsample (not the conv_block mode)
    '''
    resize = nn.Upsample(scale_factor=upscale_factor, mode=mode)
    refine = conv_block(in_nc, out_nc, kernel_size, stride, bias=bias,
                        pad_type=pad_type, norm_type=norm_type, act_type=act_type)
    return sequential(resize, refine)
|
60
codes/models/SPSR_modules/loss.py
Normal file
60
codes/models/SPSR_modules/loss.py
Normal file
|
@ -0,0 +1,60 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
# Define GAN loss: [vanilla | lsgan | wgan-gp]
|
||||
class GANLoss(nn.Module):
    '''
    GAN loss: [vanilla | lsgan | wgan-gp]

    gan_type: case-insensitive loss name; anything else raises NotImplementedError
    real_label_val / fake_label_val: target fill values for vanilla and lsgan
    '''

    def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0):
        super(GANLoss, self).__init__()
        self.gan_type = gan_type.lower()
        self.real_label_val = real_label_val
        self.fake_label_val = fake_label_val

        if self.gan_type == 'vanilla':
            self.loss = nn.BCEWithLogitsLoss()
        elif self.gan_type == 'lsgan':
            self.loss = nn.MSELoss()
        elif self.gan_type == 'wgan-gp':
            self.loss = self._wgan_loss
        else:
            raise NotImplementedError('GAN type [{:s}] is not found'.format(self.gan_type))

    @staticmethod
    def _wgan_loss(input, target):
        # target is boolean: negate the mean for real targets
        if target:
            return -1 * input.mean()
        return input.mean()

    def get_target_label(self, input, target_is_real):
        '''
        Return the target for the configured loss: the boolean itself for
        wgan-gp, otherwise a tensor shaped like input filled with the
        real/fake label value.
        '''
        if self.gan_type == 'wgan-gp':
            return target_is_real
        label_val = self.real_label_val if target_is_real else self.fake_label_val
        return torch.empty_like(input).fill_(label_val)

    def forward(self, input, target_is_real):
        '''Compute the configured loss of input against the real/fake target.'''
        return self.loss(input, self.get_target_label(input, target_is_real))
|
||||
|
||||
|
||||
class GradientPenaltyLoss(nn.Module):
    '''
    Gradient penalty: mean over the batch of (||d interp_crit / d interp||_2 - 1)^2,
    with the gradient norm taken per sample over all non-batch dimensions.

    device: device on which the cached grad_outputs buffer is created
    '''

    def __init__(self, device=torch.device('cpu')):
        super(GradientPenaltyLoss, self).__init__()
        # Starts empty; resized to the critic output shape on first use.
        self.register_buffer('grad_outputs', torch.Tensor().to(device))

    def get_grad_outputs(self, input):
        '''Return the cached all-ones buffer, resized in place to match input.'''
        ones = self.grad_outputs
        if ones.size() != input.size():
            ones.resize_(input.size()).fill_(1.0)
        return ones

    def forward(self, interp, interp_crit):
        '''
        interp: tensor the critic was evaluated on (must be part of the graph)
        interp_crit: critic output for interp
        '''
        ones = self.get_grad_outputs(interp_crit)
        # create_graph=True keeps the penalty itself differentiable.
        (grad_interp,) = torch.autograd.grad(
            outputs=interp_crit, inputs=interp, grad_outputs=ones,
            create_graph=True, retain_graph=True, only_inputs=True)
        per_sample = grad_interp.view(grad_interp.size(0), -1)
        per_sample_norm = per_sample.norm(2, dim=1)
        return ((per_sample_norm - 1) ** 2).mean()
|
94
codes/models/SPSR_modules/sampler.py
Normal file
94
codes/models/SPSR_modules/sampler.py
Normal file
|
@ -0,0 +1,94 @@
|
|||
import random
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
def _get_random_crop_indices(crop_region, crop_size):
|
||||
'''
|
||||
crop_region: (strat_y, end_y, start_x, end_x)
|
||||
crop_size: (y, x)
|
||||
'''
|
||||
region_size = (crop_region[1] - crop_region[0], crop_region[3] - crop_region[2])
|
||||
if region_size[0] < crop_size[0] or region_size[1] < crop_size[1]:
|
||||
print(region_size, crop_size)
|
||||
assert region_size[0] >= crop_size[0] and region_size[1] >= crop_size[1]
|
||||
if region_size[0] == crop_size[0]:
|
||||
start_y = crop_region[0]
|
||||
else:
|
||||
start_y = random.choice(range(crop_region[0], crop_region[1] - crop_size[0]))
|
||||
if region_size[1] == crop_size[1]:
|
||||
start_x = crop_region[2]
|
||||
else:
|
||||
start_x = random.choice(range(crop_region[2], crop_region[3] - crop_size[1]))
|
||||
return start_y, start_y + crop_size[0], start_x, start_x + crop_size[1]
|
||||
|
||||
def _get_adaptive_crop_indices(crop_region, crop_size, num_candidate, dist_map, min_diff=False):
    '''
    Draw num_candidate random crops and keep the one whose summed dist_map
    values are largest (or smallest when min_diff is True).

    crop_region: (start_y, end_y, start_x, end_x)
    crop_size: (y, x)
    num_candidate: number of random crops to compare (at least 1)
    dist_map: 2D tensor indexed as [y, x]
    '''
    candidates = [_get_random_crop_indices(crop_region, crop_size) for _ in range(num_candidate)]
    max_choice = candidates[0]
    min_choice = candidates[0]
    max_dist = 0
    # float('inf') instead of np.infty: that alias was removed in NumPy 2.0.
    min_dist = float('inf')
    with torch.no_grad():
        for c in candidates:
            start_y, end_y, start_x, end_x = c
            dist = torch.sum(dist_map[start_y: end_y, start_x: end_x])
            if dist > max_dist:
                max_dist = dist
                max_choice = c
            if dist < min_dist:
                min_dist = dist
                min_choice = c
    return min_choice if min_diff else max_choice
|
||||
|
||||
def get_split_list(divisor, dividend):
    '''
    Split dividend into divisor integer parts: the first divisor - 1 parts are
    dividend // divisor and the last part takes whatever remains.
    '''
    share = dividend // divisor
    return [share] * (divisor - 1) + [dividend - share * (divisor - 1)]
|
||||
|
||||
def random_sampler(pic_size, crop_dict):
    '''
    Sample random square crops over the whole picture.

    pic_size: (height, width)
    crop_dict: maps crop side length (int-convertible key) to number of crops
    Returns a dict with the same keys, each mapping to a list of
    (start_y, end_y, start_x, end_x) tuples.
    '''
    whole_picture = (0, pic_size[0], 0, pic_size[1])
    sampled = {}
    for size_key, count in crop_dict.items():
        side = int(size_key)
        sampled[size_key] = [_get_random_crop_indices(whole_picture, (side, side))
                             for _ in range(count)]
    return sampled
|
||||
|
||||
def region_sampler(crop_region, crop_dict):
    '''
    Sample random square crops inside a given region.

    crop_region: (start_y, end_y, start_x, end_x)
    crop_dict: maps crop side length (int-convertible key) to number of crops
    Returns a dict with the same keys, each mapping to a list of
    (start_y, end_y, start_x, end_x) tuples.
    '''
    sampled = {}
    for size_key, count in crop_dict.items():
        side = int(size_key)
        sampled[size_key] = [_get_random_crop_indices(crop_region, (side, side))
                             for _ in range(count)]
    return sampled
|
||||
|
||||
def adaptive_sampler(pic_size, crop_dict, num_candidate_dict, dist_map, min_diff=False):
    '''
    Sample square crops over the whole picture, choosing each one among
    several random candidates by its summed dist_map value.

    pic_size: (height, width)
    crop_dict: maps crop side length (int-convertible key) to number of crops
    num_candidate_dict: maps the same keys to the number of candidates per crop
    dist_map, min_diff: forwarded to _get_adaptive_crop_indices
    '''
    whole_picture = (0, pic_size[0], 0, pic_size[1])
    sampled = {}
    for size_key, count in crop_dict.items():
        side = int(size_key)
        sampled[size_key] = [
            _get_adaptive_crop_indices(whole_picture, (side, side),
                                       num_candidate_dict[size_key], dist_map, min_diff)
            for _ in range(count)
        ]
    return sampled
|
||||
|
||||
# TODO more flexible
|
||||
def pyramid_sampler(pic_size, crop_dict):
    '''
    Sample nested square crops, largest size first.

    The largest size is sampled over the whole picture. Every smaller size is
    sampled inside the crops of the next larger size, with its crop count
    divided among those parent crops by get_split_list.

    pic_size: (height, width)
    crop_dict: maps crop side length (int-convertible key) to number of crops
    Returns a dict keyed like crop_dict, in descending size order.
    '''
    sizes_desc = sorted(crop_dict.keys(), key=lambda x: int(x), reverse=True)

    largest = sizes_desc[0]
    largest_side = int(largest)
    whole_picture = (0, pic_size[0], 0, pic_size[1])
    sampled = {
        largest: [_get_random_crop_indices(whole_picture, (largest_side, largest_side))
                  for _ in range(crop_dict[largest])]
    }

    for parent_key, child_key in zip(sizes_desc, sizes_desc[1:]):
        children = []
        sampled[child_key] = children
        # How many child crops each parent crop receives.
        per_parent = get_split_list(crop_dict[parent_key], crop_dict[child_key])
        child_side = int(child_key)
        for parent_idx, child_count in enumerate(per_parent):
            parent_region = sampled[parent_key][parent_idx]
            children.extend(_get_random_crop_indices(parent_region, (child_side, child_side))
                            for _ in range(child_count))

    return sampled
|
||||
|
149
codes/models/SPSR_modules/spectral_norm.py
Normal file
149
codes/models/SPSR_modules/spectral_norm.py
Normal file
|
@ -0,0 +1,149 @@
|
|||
'''
|
||||
Copy from pytorch github repo
|
||||
Spectral Normalization from https://arxiv.org/abs/1802.05957
|
||||
'''
|
||||
import torch
|
||||
from torch.nn.functional import normalize
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
|
||||
class SpectralNorm(object):
    # Forward pre-hook object that rescales one parameter of a module by an
    # estimate of its largest singular value (see `apply`, which installs it).
    # State kept on the module:
    #   <name>_orig : the trainable parameter
    #   <name>      : buffer holding the normalized weight used by forward
    #   <name>_u    : buffer holding the running left singular vector estimate

    def __init__(self, name='weight', n_power_iterations=1, dim=0, eps=1e-12):
        # name: attribute name of the parameter to normalize
        # n_power_iterations: power-iteration steps per call; must be positive
        # dim: dimension of the weight moved to the front before flattening to 2D
        # eps: epsilon passed to `normalize`
        self.name = name
        self.dim = dim
        if n_power_iterations <= 0:
            raise ValueError('Expected n_power_iterations to be positive, but '
                             'got n_power_iterations={}'.format(n_power_iterations))
        self.n_power_iterations = n_power_iterations
        self.eps = eps

    def compute_weight(self, module):
        # Returns (weight / sigma, updated u), where sigma = u^T W v after
        # running the power iteration on the flattened weight matrix W.
        weight = getattr(module, self.name + '_orig')
        u = getattr(module, self.name + '_u')
        weight_mat = weight
        if self.dim != 0:
            # permute dim to front
            weight_mat = weight_mat.permute(self.dim,
                                            *[d for d in range(weight_mat.dim()) if d != self.dim])
        height = weight_mat.size(0)
        weight_mat = weight_mat.reshape(height, -1)
        # The u/v updates are done without tracking gradients.
        with torch.no_grad():
            for _ in range(self.n_power_iterations):
                # Spectral norm of weight equals to `u^T W v`, where `u` and `v`
                # are the first left and right singular vectors.
                # This power iteration produces approximations of `u` and `v`.
                v = normalize(torch.matmul(weight_mat.t(), u), dim=0, eps=self.eps)
                u = normalize(torch.matmul(weight_mat, v), dim=0, eps=self.eps)

        # NOTE(review): sigma is placed outside the no_grad block here, which
        # matches the blank line in the source and the upstream PyTorch code
        # this was copied from -- confirm against the original file.
        sigma = torch.dot(u, torch.matmul(weight_mat, v))
        weight = weight / sigma
        return weight, u

    def remove(self, module):
        # Undo `apply`'s attribute layout: drop the three helper attributes and
        # register the current (normalized) weight as a plain parameter again.
        # NOTE(review): the forward pre-hook itself is not unregistered here.
        weight = getattr(module, self.name)
        delattr(module, self.name)
        delattr(module, self.name + '_u')
        delattr(module, self.name + '_orig')
        module.register_parameter(self.name, torch.nn.Parameter(weight))

    def __call__(self, module, inputs):
        # Pre-forward hook body. In training mode the normalized weight and u
        # are recomputed and stored on the module; otherwise the stored weight
        # is detached in place and given the requires_grad flag of <name>_orig.
        if module.training:
            weight, u = self.compute_weight(module)
            setattr(module, self.name, weight)
            setattr(module, self.name + '_u', u)
        else:
            r_g = getattr(module, self.name + '_orig').requires_grad
            getattr(module, self.name).detach_().requires_grad_(r_g)

    @staticmethod
    def apply(module, name, n_power_iterations, dim, eps):
        # Install spectral normalization for `module.<name>` and return the
        # hook object that was registered.
        fn = SpectralNorm(name, n_power_iterations, dim, eps)
        weight = module._parameters[name]
        height = weight.size(dim)

        # Random unit-length starting vector for the power iteration.
        u = normalize(weight.new_empty(height).normal_(0, 1), dim=0, eps=fn.eps)
        delattr(module, fn.name)
        module.register_parameter(fn.name + "_orig", weight)
        # We still need to assign weight back as fn.name because all sorts of
        # things may assume that it exists, e.g., when initializing weights.
        # However, we can't directly assign as it could be an nn.Parameter and
        # gets added as a parameter. Instead, we register weight.data as a
        # buffer, which will cause weight to be included in the state dict
        # and also supports nn.init due to shared storage.
        module.register_buffer(fn.name, weight.data)
        module.register_buffer(fn.name + "_u", u)

        module.register_forward_pre_hook(fn)
        return fn
|
||||
|
||||
|
||||
def spectral_norm(module, name='weight', n_power_iterations=1, eps=1e-12, dim=None):
    r"""Applies spectral normalization to a parameter in the given module.

    .. math::
         \mathbf{W} &= \dfrac{\mathbf{W}}{\sigma(\mathbf{W})} \\
         \sigma(\mathbf{W}) &= \max_{\mathbf{h}: \mathbf{h} \ne 0} \dfrac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2}

    Spectral normalization stabilizes the training of discriminators (critics)
    in Generative Adversarial Networks (GANs) by rescaling the weight tensor
    with spectral norm :math:`\sigma` of the weight matrix calculated using
    power iteration method. If the dimension of the weight tensor is greater
    than 2, it is reshaped to 2D in power iteration method to get spectral
    norm. This is implemented via a hook that calculates spectral norm and
    rescales weight before every :meth:`~Module.forward` call.

    See `Spectral Normalization for Generative Adversarial Networks`_ .

    .. _`Spectral Normalization for Generative Adversarial Networks`: https://arxiv.org/abs/1802.05957

    Args:
        module (nn.Module): containing module
        name (str, optional): name of weight parameter
        n_power_iterations (int, optional): number of power iterations to
            calculate spectral norm
        eps (float, optional): epsilon for numerical stability in
            calculating norms
        dim (int, optional): dimension corresponding to number of outputs,
            the default is 0, except for modules that are instances of
            ConvTranspose1/2/3d, when it is 1

    Returns:
        The original module with the spectral norm hook

    Example::

        >>> m = spectral_norm(nn.Linear(20, 40))
        >>> m.weight_u.size()
        torch.Size([40])

    """
    if dim is None:
        # Transposed convolutions are the one case that defaults to dim 1.
        transposed_convs = (torch.nn.ConvTranspose1d,
                            torch.nn.ConvTranspose2d,
                            torch.nn.ConvTranspose3d)
        dim = 1 if isinstance(module, transposed_convs) else 0
    SpectralNorm.apply(module, name, n_power_iterations, dim, eps)
    return module
|
||||
|
||||
|
||||
def remove_spectral_norm(module, name='weight'):
    r"""Strip the spectral-norm reparameterization of ``name`` from ``module``.

    The first registered :class:`SpectralNorm` forward pre-hook whose ``name``
    matches is asked to restore the plain parameter and is then unregistered.

    Args:
        module (nn.Module): containing module
        name (str, optional): name of weight parameter

    Returns:
        The same ``module`` instance, without the hook.

    Raises:
        ValueError: if no spectral-norm hook for ``name`` is registered.

    Example:
        >>> m = spectral_norm(nn.Linear(40, 10))
        >>> remove_spectral_norm(m)
    """
    pre_hooks = module._forward_pre_hooks
    for hook_id, candidate in pre_hooks.items():
        if not isinstance(candidate, SpectralNorm) or candidate.name != name:
            continue
        candidate.remove(module)
        # Safe to delete while iterating: we return straight away.
        del pre_hooks[hook_id]
        return module

    raise ValueError("spectral_norm of '{}' not found in {}".format(name, module))
|
161
codes/models/SPSR_networks.py
Normal file
161
codes/models/SPSR_networks.py
Normal file
|
@ -0,0 +1,161 @@
|
|||
import functools
|
||||
import logging
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import init
|
||||
|
||||
import models.SPSR_modules.architecture as arch
|
||||
logger = logging.getLogger('base')
|
||||
####################
|
||||
# initialize
|
||||
####################
|
||||
|
||||
|
||||
def weights_init_normal(m, std=0.02):
    """Initialize one module in place from a normal distribution.

    The module kind is decided by substring match on its class name:

    * names containing ``Conv`` or ``Linear``: weight ~ N(0, std), bias zeroed
      when present;
    * names containing ``BatchNorm2d``: weight ~ N(1, std), bias set to 0;
      layers without a learnable weight (``affine=False``) are left untouched;
    * anything else is ignored.

    Args:
        m (nn.Module): module to initialize (typically supplied by
            ``net.apply``).
        std (float): standard deviation of the normal distribution.
    """
    classname = m.__class__.__name__
    if 'Conv' in classname or 'Linear' in classname:
        init.normal_(m.weight.data, 0.0, std)
        if m.bias is not None:
            m.bias.data.zero_()
    elif 'BatchNorm2d' in classname:
        # Non-affine batch norm has weight and bias set to None; skip it
        # instead of failing on `None.data`.
        if m.weight is not None:
            init.normal_(m.weight.data, 1.0, std)  # BN also uses norm
            init.constant_(m.bias.data, 0.0)
|
||||
|
||||
|
||||
def weights_init_kaiming(m, scale=1):
    """Initialize one module in place with Kaiming-normal weights.

    The module kind is decided by substring match on its class name:

    * names containing ``Conv`` or ``Linear``: weight drawn with
      ``kaiming_normal_(a=0, mode='fan_in')`` and then multiplied by
      ``scale``; bias zeroed when present;
    * names containing ``BatchNorm2d``: weight set to 1 and bias to 0, only
      when the layer is affine;
    * anything else is ignored.

    Args:
        m (nn.Module): module to initialize (typically supplied by
            ``net.apply``).
        scale (float): multiplier applied to the freshly drawn weights.
    """
    classname = m.__class__.__name__
    if 'Conv' in classname or 'Linear' in classname:
        init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
        m.weight.data *= scale
        if m.bias is not None:
            m.bias.data.zero_()
    elif 'BatchNorm2d' in classname:
        # Non-affine batch norm has no weight/bias tensors to initialize.
        if m.affine:
            init.constant_(m.weight.data, 1.0)
            init.constant_(m.bias.data, 0.0)
|
||||
|
||||
|
||||
def weights_init_orthogonal(m):
    """Initialize one module in place with orthogonal weights.

    The module kind is decided by substring match on its class name:

    * names containing ``Conv`` or ``Linear``: weight initialized with
      ``orthogonal_(gain=1)``; bias zeroed when present;
    * names containing ``BatchNorm2d``: weight set to 1 and bias to 0;
      layers without a learnable weight (``affine=False``) are left untouched;
    * anything else is ignored.

    Args:
        m (nn.Module): module to initialize (typically supplied by
            ``net.apply``).
    """
    classname = m.__class__.__name__
    if 'Conv' in classname or 'Linear' in classname:
        init.orthogonal_(m.weight.data, gain=1)
        if m.bias is not None:
            m.bias.data.zero_()
    elif 'BatchNorm2d' in classname:
        # Non-affine batch norm has weight and bias set to None; skip it
        # instead of failing on `None.data`.
        if m.weight is not None:
            init.constant_(m.weight.data, 1.0)
            init.constant_(m.bias.data, 0.0)
|
||||
|
||||
|
||||
def init_weights(net, init_type='kaiming', scale=1, std=0.02):
    """Apply the chosen per-module initializer to every submodule of ``net``.

    Args:
        net (nn.Module): network to initialize in place.
        init_type (str): one of ``'normal'``, ``'kaiming'`` or
            ``'orthogonal'``.
        scale (float): forwarded to the Kaiming initializer only.
        std (float): forwarded to the normal initializer only.

    Raises:
        NotImplementedError: if ``init_type`` is not one of the above.
    """
    if init_type not in ('normal', 'kaiming', 'orthogonal'):
        raise NotImplementedError('initialization method [{:s}] not implemented'.format(init_type))

    # scale for 'kaiming', std for 'normal'.
    if init_type == 'normal':
        initializer = functools.partial(weights_init_normal, std=std)
    elif init_type == 'kaiming':
        initializer = functools.partial(weights_init_kaiming, scale=scale)
    else:
        initializer = weights_init_orthogonal
    net.apply(initializer)
|
||||
|
||||
|
||||
####################
|
||||
# define network
|
||||
####################
|
||||
|
||||
|
||||
# Generator
|
||||
def define_G(opt, device=None):
    """Build the generator described by ``opt['network_G']``.

    Only ``which_model_G == 'spsr_net'`` is supported. When
    ``opt['is_train']`` is truthy the network is Kaiming-initialized with a
    scale of 0.1. ``device`` is accepted but not used by this function.

    Raises:
        NotImplementedError: for any other ``which_model_G`` value.
    """
    net_cfg = opt['network_G']
    model_name = net_cfg['which_model_G']

    if model_name != 'spsr_net':
        raise NotImplementedError('Generator model [{:s}] not recognized'.format(model_name))

    netG = arch.SPSRNet(
        in_nc=net_cfg['in_nc'],
        out_nc=net_cfg['out_nc'],
        nf=net_cfg['nf'],
        nb=net_cfg['nb'],
        gc=net_cfg['gc'],
        upscale=net_cfg['scale'],
        norm_type=net_cfg['norm_type'],
        act_type='leakyrelu',
        mode=net_cfg['mode'],
        upsample_mode='upconv',
    )

    if opt['is_train']:
        init_weights(netG, init_type='kaiming', scale=0.1)

    return netG
|
||||
|
||||
|
||||
|
||||
# Discriminator
|
||||
def define_D(opt):
    """Build and Kaiming-initialize the discriminator from ``opt['network_D']``.

    Supported ``which_model_D`` values are ``discriminator_vgg_128``,
    ``discriminator_vgg_96``, ``discriminator_vgg_192`` (all configured from
    the ``in_nc``, ``nf``, ``norm_type``, ``mode`` and ``act_type`` options)
    and ``discriminator_vgg_128_SN`` (constructed without arguments).

    Raises:
        NotImplementedError: for any other ``which_model_D`` value.
    """
    net_cfg = opt['network_D']
    model_name = net_cfg['which_model_D']

    # Option name -> class name in `arch` for the variants sharing one signature.
    vgg_variants = {
        'discriminator_vgg_128': 'Discriminator_VGG_128',
        'discriminator_vgg_96': 'Discriminator_VGG_96',
        'discriminator_vgg_192': 'Discriminator_VGG_192',
    }

    if model_name in vgg_variants:
        build = getattr(arch, vgg_variants[model_name])
        netD = build(
            in_nc=net_cfg['in_nc'],
            base_nf=net_cfg['nf'],
            norm_type=net_cfg['norm_type'],
            mode=net_cfg['mode'],
            act_type=net_cfg['act_type'],
        )
    elif model_name == 'discriminator_vgg_128_SN':
        netD = arch.Discriminator_VGG_128_SN()
    else:
        raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(model_name))

    init_weights(netD, init_type='kaiming', scale=1)

    return netD
|
||||
|
||||
def define_D_grad(opt):
    """Build the discriminator requested through the ``_grad`` entry point.

    The construction rules (read from ``opt['network_D']``) are exactly those
    of :func:`define_D`, so this delegates to it rather than keeping a second
    copy of the same branch table. Each call returns a freshly built and
    Kaiming-initialized network.

    Raises:
        NotImplementedError: if ``which_model_D`` is not recognized.
    """
    return define_D(opt)
|
||||
|
||||
|
||||
def define_F(opt, use_bn=False):
    """Build the VGG feature extractor, switched to eval mode.

    Args:
        opt: options dictionary; not read by this function.
        use_bn (bool): select the batch-norm variant; this also changes the
            feature layer index (49 with batch norm, 34 without).

    Returns:
        The ``arch.VGGFeatureExtractor`` instance after ``eval()``.
    """
    # Prefer CUDA, but fall back to CPU so construction does not depend on a
    # GPU being present.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # pytorch pretrained VGG19-54, before ReLU.
    feature_layer = 49 if use_bn else 34
    netF = arch.VGGFeatureExtractor(feature_layer=feature_layer, use_bn=use_bn,
                                    use_input_norm=True, device=device)

    netF.eval()
    return netF
|
|
@ -11,6 +11,8 @@ def create_model(opt):
|
|||
from .SRGAN_model import SRGANModel as M
|
||||
elif model == 'feat':
|
||||
from .feature_model import FeatureModel as M
|
||||
if model == 'spsr':
|
||||
from .SPSR_model import SPSRModel as M
|
||||
else:
|
||||
raise NotImplementedError('Model [{:s}] not recognized.'.format(model))
|
||||
m = M(opt)
|
||||
|
|
|
@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs):
|
|||
def main():
|
||||
#### options
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_pixgan_srg4_lr_feat.yml')
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_spsr.yml')
|
||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
|
||||
help='job launcher')
|
||||
parser.add_argument('--local_rank', type=int, default=0)
|
||||
|
|
Loading…
Reference in New Issue
Block a user