Substantial SPSR mods & fixes

- Added gradient accumulation via mega-batch-factor
- Added AMP
- Added missing train hooks
- Added debug image outputs
- Cleaned up, including removing GradientPenaltyLoss and the custom SpectralNorm
- Removed all the custom discriminators
This commit is contained in:
James Betker 2020-08-02 10:45:24 -06:00
parent f894ba8f98
commit c8da78966b
7 changed files with 163 additions and 622 deletions

View File

@ -5,10 +5,13 @@ from collections import OrderedDict
import torch
import torch.nn as nn
from torch.optim import lr_scheduler
from apex import amp
import models.SPSR_networks as networks
from .base_model import BaseModel
from models.SPSR_modules.loss import GANLoss, GradientPenaltyLoss
from models.SPSR_modules.loss import GANLoss
import torchvision.utils as utils
logger = logging.getLogger('base')
import torch.nn.functional as F
@ -95,10 +98,13 @@ class SPSRModel(BaseModel):
self.netG.train()
self.netD.train()
self.netD_grad.train()
self.mega_batch_factor = 1
self.load() # load G and D if needed
# define losses, optimizer and scheduler
if self.is_train:
self.mega_batch_factor = train_opt['mega_batch_factor']
# G pixel loss
if train_opt['pixel_weight'] > 0:
l_pix_type = train_opt['pixel_criterion']
@ -139,12 +145,6 @@ class SPSRModel(BaseModel):
self.Branch_pretrain = train_opt['Branch_pretrain'] if train_opt['Branch_pretrain'] else 0
self.Branch_init_iters = train_opt['Branch_init_iters'] if train_opt['Branch_init_iters'] else 1
if train_opt['gan_type'] == 'wgan-gp':
self.random_pt = torch.Tensor(1, 1, 1, 1).to(self.device)
# gradient penalty loss
self.cri_gp = GradientPenaltyLoss(device=self.device).to(self.device)
self.l_gp_w = train_opt['gp_weigth']
# gradient_pixel_loss
if train_opt['gradient_pixel_weight'] > 0:
self.cri_pix_grad = nn.MSELoss().to(self.device)
@ -202,6 +202,12 @@ class SPSRModel(BaseModel):
self.optimizers.append(self.optimizer_D_grad)
# AMP
[self.netG, self.netD, self.netD_grad], [self.optimizer_G, self.optimizer_D, self.optimizer_D_grad] = \
amp.initialize([self.netG, self.netD, self.netD_grad],
[self.optimizer_G, self.optimizer_D, self.optimizer_D_grad],
opt_level=self.amp_level, num_losses=3)
# schedulers
if train_opt['lr_scheme'] == 'MultiStepLR':
for optimizer in self.optimizers:
@ -216,12 +222,12 @@ class SPSRModel(BaseModel):
def feed_data(self, data, need_HR=True):
# LR
self.var_L = data['LQ'].to(self.device)
self.var_L = [t.to(self.device) for t in torch.chunk(data['LQ'], chunks=self.mega_batch_factor, dim=0)]
if need_HR: # train or val
self.var_H = data['GT'].to(self.device)
self.var_H = [t.to(self.device) for t in torch.chunk(data['GT'], chunks=self.mega_batch_factor, dim=0)]
input_ref = data['ref'] if 'ref' in data else data['GT']
self.var_ref = input_ref.to(self.device)
self.var_ref = [t.to(self.device) for t in torch.chunk(input_ref.to(self.device), chunks=self.mega_batch_factor, dim=0)]
@ -247,118 +253,118 @@ class SPSRModel(BaseModel):
self.optimizer_G.zero_grad()
self.fake_H_branch, self.fake_H, self.grad_LR = self.netG(self.var_L)
self.fake_H_grad = self.get_grad(self.fake_H)
self.var_H_grad = self.get_grad(self.var_H)
self.var_ref_grad = self.get_grad(self.var_ref)
self.var_H_grad_nopadding = self.get_grad_nopadding(self.var_H)
self.fake_H_branch = []
self.fake_H = []
self.grad_LR = []
for var_L, var_H, var_ref in zip(self.var_L, self.var_H, self.var_ref):
fake_H_branch, fake_H, grad_LR = self.netG(var_L)
self.fake_H_branch.append(fake_H_branch.detach())
self.fake_H.append(fake_H.detach())
self.grad_LR.append(grad_LR.detach())
fake_H_grad = self.get_grad(fake_H)
var_H_grad = self.get_grad(var_H)
var_ref_grad = self.get_grad(var_ref)
var_H_grad_nopadding = self.get_grad_nopadding(var_H)
l_g_total = 0
if step % self.D_update_ratio == 0 and step > self.D_init_iters:
if self.cri_pix: # pixel loss
l_g_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.var_H)
l_g_pix = self.l_pix_w * self.cri_pix(fake_H, var_H)
l_g_total += l_g_pix
if self.cri_fea: # feature loss
real_fea = self.netF(self.var_H).detach()
fake_fea = self.netF(self.fake_H)
real_fea = self.netF(var_H).detach()
fake_fea = self.netF(fake_H)
l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea)
l_g_total += l_g_fea
if self.cri_pix_grad: #gradient pixel loss
l_g_pix_grad = self.l_pix_grad_w * self.cri_pix_grad(self.fake_H_grad, self.var_H_grad)
l_g_pix_grad = self.l_pix_grad_w * self.cri_pix_grad(fake_H_grad, var_H_grad)
l_g_total += l_g_pix_grad
if self.cri_pix_branch: #branch pixel loss
l_g_pix_grad_branch = self.l_pix_branch_w * self.cri_pix_branch(self.fake_H_branch, self.var_H_grad_nopadding)
l_g_pix_grad_branch = self.l_pix_branch_w * self.cri_pix_branch(fake_H_branch, var_H_grad_nopadding)
l_g_total += l_g_pix_grad_branch
if self.l_gan_w > 0:
# G gan + cls loss
pred_g_fake = self.netD(self.fake_H)
pred_d_real = self.netD(self.var_ref).detach()
pred_g_fake = self.netD(fake_H)
pred_d_real = self.netD(var_ref).detach()
l_g_gan = self.l_gan_w * (self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False) +
self.cri_gan(pred_g_fake - torch.mean(pred_d_real), True)) / 2
l_g_total += l_g_gan
if self.cri_grad_gan:
# grad G gan + cls loss
pred_g_fake_grad = self.netD_grad(self.fake_H_grad)
pred_d_real_grad = self.netD_grad(self.var_ref_grad).detach()
pred_g_fake_grad = self.netD_grad(fake_H_grad)
pred_d_real_grad = self.netD_grad(var_ref_grad).detach()
l_g_gan_grad = self.l_gan_grad_w * (self.cri_grad_gan(pred_d_real_grad - torch.mean(pred_g_fake_grad), False) +
self.cri_grad_gan(pred_g_fake_grad - torch.mean(pred_d_real_grad), True)) /2
l_g_total += l_g_gan_grad
l_g_total /= self.mega_batch_factor
with amp.scale_loss(l_g_total, self.optimizer_G, loss_id=0) as l_g_total_scaled:
l_g_total_scaled.backward()
l_g_total.backward()
if step % self.D_update_ratio == 0 and step > self.D_init_iters:
self.optimizer_G.step()
if self.l_gan_w > 0:
# D
for p in self.netD.parameters():
p.requires_grad = True
self.optimizer_D.zero_grad()
l_d_total = 0
pred_d_real = self.netD(self.var_ref)
pred_d_fake = self.netD(self.fake_H.detach()) # detach to avoid BP to G
for var_ref, fake_H in zip(self.var_ref, self.fake_H):
pred_d_real = self.netD(var_ref)
pred_d_fake = self.netD(fake_H) # detach to avoid BP to G
l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True)
l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real), False)
l_d_total = (l_d_real + l_d_fake) / 2
if self.opt['train']['gan_type'] == 'wgan-gp':
batch_size = self.var_ref.size(0)
if self.random_pt.size(0) != batch_size:
self.random_pt.resize_(batch_size, 1, 1, 1)
self.random_pt.uniform_() # Draw random interpolation points
interp = self.random_pt * self.fake_H.detach() + (1 - self.random_pt) * self.var_ref
interp.requires_grad = True
interp_crit, _ = self.netD(interp)
l_d_gp = self.l_gp_w * self.cri_gp(interp, interp_crit)
l_d_total += l_d_gp
l_d_total.backward()
l_d_total /= self.mega_batch_factor
with amp.scale_loss(l_d_total, self.optimizer_D, loss_id=1) as l_d_total_scaled:
l_d_total_scaled.backward()
self.optimizer_D.step()
if self.cri_grad_gan:
for p in self.netD_grad.parameters():
p.requires_grad = True
self.optimizer_D_grad.zero_grad()
l_d_total_grad = 0
for var_ref, fake_H in zip(self.var_ref, self.fake_H):
fake_H_grad = self.get_grad(fake_H)
var_ref_grad = self.get_grad(var_ref)
pred_d_real_grad = self.netD_grad(self.var_ref_grad)
pred_d_fake_grad = self.netD_grad(self.fake_H_grad.detach()) # detach to avoid BP to G
pred_d_real_grad = self.netD_grad(var_ref_grad)
pred_d_fake_grad = self.netD_grad(fake_H_grad.detach()) # detach to avoid BP to G
l_d_real_grad = self.cri_grad_gan(pred_d_real_grad - torch.mean(pred_d_fake_grad), True)
l_d_fake_grad = self.cri_grad_gan(pred_d_fake_grad - torch.mean(pred_d_real_grad), False)
l_d_total_grad = (l_d_real_grad + l_d_fake_grad) / 2
l_d_total_grad /= self.mega_batch_factor
l_d_total_grad.backward()
with amp.scale_loss(l_d_total_grad, self.optimizer_D_grad, loss_id=2) as l_d_total_grad_scaled:
l_d_total_grad_scaled.backward()
self.optimizer_D_grad.step()
# Log sample images from first microbatch.
if step % 50 == 0:
import torchvision.utils as utils
sample_save_path = os.path.join(self.opt['path']['models'], "..", "temp")
os.makedirs(os.path.join(sample_save_path, "hr"), exist_ok=True)
os.makedirs(os.path.join(sample_save_path, "lr"), exist_ok=True)
os.makedirs(os.path.join(sample_save_path, "gen"), exist_ok=True)
# fed_LQ is not chunked.
utils.save_image(self.var_H.cpu(), os.path.join(sample_save_path, "hr", "%05i.png" % (step,)))
utils.save_image(self.var_L.cpu(), os.path.join(sample_save_path, "lr", "%05i.png" % (step,)))
utils.save_image(self.fake_H.cpu(), os.path.join(sample_save_path, "gen", "%05i.png" % (step,)))
utils.save_image(self.var_H[0].cpu(), os.path.join(sample_save_path, "hr", "%05i.png" % (step,)))
utils.save_image(self.var_L[0].cpu(), os.path.join(sample_save_path, "lr", "%05i.png" % (step,)))
utils.save_image(self.fake_H[0].cpu(), os.path.join(sample_save_path, "gen", "%05i.png" % (step,)))
# set log
@ -368,11 +374,13 @@ class SPSRModel(BaseModel):
self.log_dict['l_g_pix'] = l_g_pix.item()
if self.cri_fea:
self.log_dict['l_g_fea'] = l_g_fea.item()
if self.l_gan_w > 0:
self.log_dict['l_g_gan'] = l_g_gan.item()
if self.cri_pix_branch: #branch pixel loss
self.log_dict['l_g_pix_grad_branch'] = l_g_pix_grad_branch.item()
if self.l_gan_w > 0:
# D
self.log_dict['l_d_real'] = l_d_real.item()
self.log_dict['l_d_fake'] = l_d_fake.item()
@ -394,7 +402,14 @@ class SPSRModel(BaseModel):
def test(self):
self.netG.eval()
with torch.no_grad():
self.fake_H_branch, self.fake_H, self.grad_LR = self.netG(self.var_L)
self.fake_H_branch = []
self.fake_H = []
self.grad_LR = []
for var_L in self.var_L:
fake_H_branch, fake_H, grad_LR = self.netG(var_L)
self.fake_H_branch.append(fake_H_branch)
self.fake_H.append(fake_H)
self.grad_LR.append(grad_LR)
self.netG.train()
@ -403,13 +418,13 @@ class SPSRModel(BaseModel):
def get_current_visuals(self, need_HR=True):
out_dict = OrderedDict()
out_dict['LR'] = self.var_L.detach()[0].float().cpu()
out_dict['LR'] = self.var_L[0].float().cpu()
out_dict['SR'] = self.fake_H.detach()[0].float().cpu()
out_dict['SR_branch'] = self.fake_H_branch.detach()[0].float().cpu()
out_dict['LR_grad'] = self.grad_LR.detach()[0].float().cpu()
out_dict['rlt'] = self.fake_H[0].float().cpu()
out_dict['SR_branch'] = self.fake_H_branch[0].float().cpu()
out_dict['LR_grad'] = self.grad_LR[0].float().cpu()
if need_HR:
out_dict['HR'] = self.var_H.detach()[0].float().cpu()
out_dict['GT'] = self.var_H[0].float().cpu()
return out_dict
def print_network(self):
@ -456,7 +471,31 @@ class SPSRModel(BaseModel):
logger.info('Loading pretrained model for D [{:s}] ...'.format(load_path_D))
self.load_network(load_path_D, self.netD)
def compute_fea_loss(self, real, fake):
    """Return the feature loss between one real and one generated image.

    Each input gets a leading batch axis (``unsqueeze(dim=0)``) and is moved
    to ``self.device`` before being passed through ``self.netF``; the two
    feature outputs are then compared with ``self.cri_fea``.  Everything runs
    under ``torch.no_grad()``, so no graph is built.

    Args:
        real: un-batched reference tensor.
        fake: un-batched generated tensor.

    Returns:
        ``0`` when no feature criterion is configured (``self.cri_fea`` is
        ``None``); otherwise the criterion value as a Python number
        (``.item()``).
    """
    if self.cri_fea is None:
        return 0

    def _as_batch(image):
        # Add the batch axis first, then move to the model's device.
        return image.unsqueeze(dim=0).to(self.device)

    with torch.no_grad():
        real_batch = _as_batch(real)
        fake_batch = _as_batch(fake)
        target_features = self.netF(real_batch).detach()
        predicted_features = self.netF(fake_batch)
        loss = self.cri_fea(predicted_features, target_features)
        return loss.item()
def force_restore_swapout(self):
    """Do nothing; present so callers can invoke this hook on the model.

    The training script calls ``model.force_restore_swapout()`` right before
    validation; this implementation has no work to perform.
    """
    return None
def save(self, iter_step):
    """Write the generator and both discriminators via ``self.save_network``.

    Args:
        iter_step: value forwarded to ``save_network`` as its third argument
            for each of the three networks.
    """
    # (attribute holding the network, label handed to save_network), in the
    # order they are written: G, then D, then D_grad.
    targets = (('netG', 'G'), ('netD', 'D'), ('netD_grad', 'D_grad'))
    for attribute, label in targets:
        self.save_network(getattr(self, attribute), label, iter_step)
# override of load_network that allows loading partial params (like RRDB_PSNR_x4)
def load_network(self, load_path, network, strict=True):
    """Load a checkpoint into ``network``, tolerating partial checkpoints.

    Only checkpoint entries whose key also exists in the network's own
    ``state_dict`` are used; every other parameter keeps its current value
    and unknown checkpoint keys are silently ignored.

    Args:
        load_path: path of the file to read with ``torch.load``.
        network: module to populate; an ``nn.DataParallel`` wrapper is
            unwrapped first so keys are matched against the inner module.
        strict: accepted for signature compatibility only.  It is
            intentionally ignored -- this method always performs the lenient
            merge described above.
    """
    if isinstance(network, nn.DataParallel):
        network = network.module
    # Deserialize onto the CPU: without map_location, torch.load tries to put
    # tensors back on the device they were saved from, which fails when that
    # device does not exist here.  load_state_dict copies the values into the
    # network's existing parameters, so their device is unaffected.
    pretrained_dict = torch.load(load_path, map_location='cpu')
    model_dict = network.state_dict()
    # Overlay only the keys the network actually has.
    model_dict.update({k: v for k, v in pretrained_dict.items() if k in model_dict})
    network.load_state_dict(model_dict)

View File

@ -4,8 +4,6 @@ import torch.nn as nn
import torch.nn.functional as F
import torchvision
from . import block as B
from . import spectral_norm as SN
class Get_gradient_nopadding(nn.Module):
@ -248,292 +246,6 @@ class Discriminator_VGG_128(nn.Module):
return x
# VGG style Discriminator with input size 96*96
class Discriminator_VGG_96(nn.Module):
    """VGG-style discriminator for 96x96 inputs with four stride-2 stages.

    The feature extractor alternates a 3x3 conv with a stride-2 4x4 conv at
    widths ``base_nf`` .. ``base_nf * 8``; per the original size notes the
    map reaching the classifier is 6x6 (this assumes ``B.conv_block`` pads so
    that only the stride changes resolution).  The classifier is a two-layer
    MLP producing one score per sample.

    NOTE: this file defines ``Discriminator_VGG_96`` a second time further
    down; the later definition replaces this one at import time.

    Args:
        in_nc: number of input channels.
        base_nf: channel width of the first stage.
        norm_type, act_type, mode: forwarded to ``B.conv_block``; the very
            first conv is always built with ``norm_type=None``.
    """

    def __init__(self, in_nc, base_nf, norm_type='batch', act_type='leakyrelu', mode='CNA'):
        super(Discriminator_VGG_96, self).__init__()
        # Keyword arguments shared by every block except the un-normalised stem.
        common = dict(norm_type=norm_type, act_type=act_type, mode=mode)
        # features (hxw, c -- channel counts quoted for base_nf = 64)
        # 96, 3
        layers = [
            B.conv_block(in_nc, base_nf, kernel_size=3, norm_type=None, act_type=act_type,
                         mode=mode),
            B.conv_block(base_nf, base_nf, kernel_size=4, stride=2, **common),
        ]
        # 48, 64 -> 24, 128 -> 12, 256 -> 6, 512
        for nf_in, nf_out in ((base_nf, base_nf * 2),
                              (base_nf * 2, base_nf * 4),
                              (base_nf * 4, base_nf * 8)):
            layers.append(B.conv_block(nf_in, nf_out, kernel_size=3, stride=1, **common))
            layers.append(B.conv_block(nf_out, nf_out, kernel_size=4, stride=2, **common))
        self.features = B.sequential(*layers)

        # classifier: input width follows base_nf (a hard-coded 512 only
        # matches base_nf == 64)
        self.classifier = nn.Sequential(
            nn.Linear(base_nf * 8 * 6 * 6, 100), nn.LeakyReLU(0.2, True), nn.Linear(100, 1))

    def forward(self, x):
        """Return the classifier output for a batch ``x``."""
        x = self.features(x)
        x = x.view(x.size(0), -1)  # flatten everything but the batch axis
        x = self.classifier(x)
        return x
# VGG style Discriminator with input size 64*64
class Discriminator_VGG_64(nn.Module):
    """VGG-style discriminator for 64x64 inputs with four stride-2 stages.

    The feature extractor alternates a 3x3 conv with a stride-2 4x4 conv at
    widths ``base_nf`` .. ``base_nf * 8``; per the original size notes the
    map reaching the classifier is 4x4 (this assumes ``B.conv_block`` pads so
    that only the stride changes resolution).  The classifier is a two-layer
    MLP producing one score per sample.

    Args:
        in_nc: number of input channels.
        base_nf: channel width of the first stage.
        norm_type, act_type, mode: forwarded to ``B.conv_block``; the very
            first conv is always built with ``norm_type=None``.
    """

    def __init__(self, in_nc, base_nf, norm_type='batch', act_type='leakyrelu', mode='CNA'):
        super(Discriminator_VGG_64, self).__init__()
        # Keyword arguments shared by every block except the un-normalised stem.
        common = dict(norm_type=norm_type, act_type=act_type, mode=mode)
        # features (hxw, c -- channel counts quoted for base_nf = 64)
        # 64, 3
        layers = [
            B.conv_block(in_nc, base_nf, kernel_size=3, norm_type=None, act_type=act_type,
                         mode=mode),
            B.conv_block(base_nf, base_nf, kernel_size=4, stride=2, **common),
        ]
        # 32, 64 -> 16, 128 -> 8, 256 -> 4, 512
        for nf_in, nf_out in ((base_nf, base_nf * 2),
                              (base_nf * 2, base_nf * 4),
                              (base_nf * 4, base_nf * 8)):
            layers.append(B.conv_block(nf_in, nf_out, kernel_size=3, stride=1, **common))
            layers.append(B.conv_block(nf_out, nf_out, kernel_size=4, stride=2, **common))
        self.features = B.sequential(*layers)

        # classifier: input width follows base_nf (a hard-coded 512 only
        # matches base_nf == 64)
        self.classifier = nn.Sequential(
            nn.Linear(base_nf * 8 * 4 * 4, 100), nn.LeakyReLU(0.2, True), nn.Linear(100, 1))

    def forward(self, x):
        """Return the classifier output for a batch ``x``."""
        x = self.features(x)
        x = x.view(x.size(0), -1)  # flatten everything but the batch axis
        x = self.classifier(x)
        return x
# VGG style Discriminator with input size 32*32
class Discriminator_VGG_32(nn.Module):
    """VGG-style discriminator for 32x32 inputs with three stride-2 stages.

    The feature extractor alternates a 3x3 conv with a stride-2 4x4 conv at
    widths ``base_nf`` .. ``base_nf * 4``; per the original size notes the
    map reaching the classifier is 4x4 (this assumes ``B.conv_block`` pads so
    that only the stride changes resolution).  The classifier is a two-layer
    MLP producing one score per sample.

    Args:
        in_nc: number of input channels.
        base_nf: channel width of the first stage.
        norm_type, act_type, mode: forwarded to ``B.conv_block``; the very
            first conv is always built with ``norm_type=None``.
    """

    def __init__(self, in_nc, base_nf, norm_type='batch', act_type='leakyrelu', mode='CNA'):
        super(Discriminator_VGG_32, self).__init__()
        # Keyword arguments shared by every block except the un-normalised stem.
        common = dict(norm_type=norm_type, act_type=act_type, mode=mode)
        # features (hxw, c -- channel counts quoted for base_nf = 64)
        # 32, 3
        layers = [
            B.conv_block(in_nc, base_nf, kernel_size=3, norm_type=None, act_type=act_type,
                         mode=mode),
            B.conv_block(base_nf, base_nf, kernel_size=4, stride=2, **common),
        ]
        # 16, 64 -> 8, 128 -> 4, 256
        for nf_in, nf_out in ((base_nf, base_nf * 2),
                              (base_nf * 2, base_nf * 4)):
            layers.append(B.conv_block(nf_in, nf_out, kernel_size=3, stride=1, **common))
            layers.append(B.conv_block(nf_out, nf_out, kernel_size=4, stride=2, **common))
        self.features = B.sequential(*layers)

        # classifier: input width follows base_nf (a hard-coded 256 only
        # matches base_nf == 64)
        self.classifier = nn.Sequential(
            nn.Linear(base_nf * 4 * 4 * 4, 100), nn.LeakyReLU(0.2, True), nn.Linear(100, 1))

    def forward(self, x):
        """Return the classifier output for a batch ``x``."""
        x = self.features(x)
        x = x.view(x.size(0), -1)  # flatten everything but the batch axis
        x = self.classifier(x)
        return x
# VGG style Discriminator with input size 16*16
class Discriminator_VGG_16(nn.Module):
    """VGG-style discriminator for 16x16 inputs with two stride-2 stages.

    The feature extractor alternates a 3x3 conv with a stride-2 4x4 conv at
    widths ``base_nf`` and ``base_nf * 2``; per the original size notes the
    map reaching the classifier is 4x4 (this assumes ``B.conv_block`` pads so
    that only the stride changes resolution).  The classifier is a two-layer
    MLP producing one score per sample.

    Args:
        in_nc: number of input channels.
        base_nf: channel width of the first stage.
        norm_type, act_type, mode: forwarded to ``B.conv_block``; the very
            first conv is always built with ``norm_type=None``.
    """

    def __init__(self, in_nc, base_nf, norm_type='batch', act_type='leakyrelu', mode='CNA'):
        super(Discriminator_VGG_16, self).__init__()
        # Keyword arguments shared by every block except the un-normalised stem.
        common = dict(norm_type=norm_type, act_type=act_type, mode=mode)
        # features (hxw, c -- channel counts quoted for base_nf = 64)
        # 16, 3
        conv0 = B.conv_block(in_nc, base_nf, kernel_size=3, norm_type=None, act_type=act_type,
                             mode=mode)
        conv1 = B.conv_block(base_nf, base_nf, kernel_size=4, stride=2, **common)
        # 8, 64
        conv2 = B.conv_block(base_nf, base_nf * 2, kernel_size=3, stride=1, **common)
        conv3 = B.conv_block(base_nf * 2, base_nf * 2, kernel_size=4, stride=2, **common)
        # 4, 128
        self.features = B.sequential(conv0, conv1, conv2, conv3)

        # classifier: input width follows base_nf (a hard-coded 128 only
        # matches base_nf == 64)
        self.classifier = nn.Sequential(
            nn.Linear(base_nf * 2 * 4 * 4, 100), nn.LeakyReLU(0.2, True), nn.Linear(100, 1))

    def forward(self, x):
        """Return the classifier output for a batch ``x``."""
        x = self.features(x)
        x = x.view(x.size(0), -1)  # flatten everything but the batch axis
        x = self.classifier(x)
        return x
# VGG style Discriminator with input size 128*128, Spectral Normalization
class Discriminator_VGG_128_SN(nn.Module):
    """Fixed-width VGG-style discriminator whose layers are spectrally normalised.

    Ten convolutions (``conv0`` .. ``conv9``) and two linear layers
    (``linear0``, ``linear1``) are each wrapped with ``SN.spectral_norm``.
    Per the original size notes a 128x128 input is reduced to 4x4 with 512
    channels before the classifier.  A shared in-place LeakyReLU(0.2) follows
    every layer except the last linear one.
    """

    # (in_channels, out_channels, kernel_size, stride) for conv0 .. conv9;
    # every conv uses padding 1.
    _CONV_SPECS = (
        (3, 64, 3, 1), (64, 64, 4, 2),        # 128 -> 64, 64 channels
        (64, 128, 3, 1), (128, 128, 4, 2),    # 64 -> 32, 128 channels
        (128, 256, 3, 1), (256, 256, 4, 2),   # 32 -> 16, 256 channels
        (256, 512, 3, 1), (512, 512, 4, 2),   # 16 -> 8, 512 channels
        (512, 512, 3, 1), (512, 512, 4, 2),   # 8 -> 4, 512 channels
    )

    def __init__(self):
        super(Discriminator_VGG_128_SN, self).__init__()
        # features
        self.lrelu = nn.LeakyReLU(0.2, True)
        # Register the convs under their conv<N> attribute names, in order.
        for index, (c_in, c_out, kernel, stride) in enumerate(self._CONV_SPECS):
            conv = SN.spectral_norm(nn.Conv2d(c_in, c_out, kernel, stride, 1))
            setattr(self, 'conv%d' % index, conv)

        # classifier
        self.linear0 = SN.spectral_norm(nn.Linear(512 * 4 * 4, 100))
        self.linear1 = SN.spectral_norm(nn.Linear(100, 1))

    def forward(self, x):
        """Return the classifier output for a batch ``x``."""
        for index in range(len(self._CONV_SPECS)):
            conv = getattr(self, 'conv%d' % index)
            x = self.lrelu(conv(x))
        flat = x.view(x.size(0), -1)  # flatten everything but the batch axis
        hidden = self.lrelu(self.linear0(flat))
        return self.linear1(hidden)
class Discriminator_VGG_96(nn.Module):
    """VGG-style discriminator for 96x96 inputs with five stride-2 stages.

    The feature extractor alternates a 3x3 conv with a stride-2 4x4 conv at
    widths ``base_nf`` .. ``base_nf * 8`` (the last width is used twice); per
    the original size notes the map reaching the classifier is 3x3 (this
    assumes ``B.conv_block`` pads so that only the stride changes
    resolution).  The classifier is a two-layer MLP producing one score per
    sample.

    NOTE: an earlier class with this same name exists in this file; being
    defined later, this is the one that remains bound to the name.

    Args:
        in_nc: number of input channels.
        base_nf: channel width of the first stage.
        norm_type, act_type, mode: forwarded to ``B.conv_block``; the very
            first conv is always built with ``norm_type=None``.
    """

    def __init__(self, in_nc, base_nf, norm_type='batch', act_type='leakyrelu', mode='CNA'):
        super(Discriminator_VGG_96, self).__init__()
        # Keyword arguments shared by every block except the un-normalised stem.
        common = dict(norm_type=norm_type, act_type=act_type, mode=mode)
        # features (hxw, c -- channel counts quoted for base_nf = 64)
        # 96, 64
        layers = [
            B.conv_block(in_nc, base_nf, kernel_size=3, norm_type=None, act_type=act_type,
                         mode=mode),
            B.conv_block(base_nf, base_nf, kernel_size=4, stride=2, **common),
        ]
        # 48, 64 -> 24, 128 -> 12, 256 -> 6, 512 -> 3, 512
        for nf_in, nf_out in ((base_nf, base_nf * 2),
                              (base_nf * 2, base_nf * 4),
                              (base_nf * 4, base_nf * 8),
                              (base_nf * 8, base_nf * 8)):
            layers.append(B.conv_block(nf_in, nf_out, kernel_size=3, stride=1, **common))
            layers.append(B.conv_block(nf_out, nf_out, kernel_size=4, stride=2, **common))
        self.features = B.sequential(*layers)

        # classifier: input width follows base_nf (a hard-coded 512 only
        # matches base_nf == 64)
        self.classifier = nn.Sequential(
            nn.Linear(base_nf * 8 * 3 * 3, 100), nn.LeakyReLU(0.2, True), nn.Linear(100, 1))

    def forward(self, x):
        """Return the classifier output for a batch ``x``."""
        x = self.features(x)
        x = x.view(x.size(0), -1)  # flatten everything but the batch axis
        x = self.classifier(x)
        return x
class Discriminator_VGG_192(nn.Module):
    """VGG-style discriminator for 192x192 inputs with six stride-2 stages.

    The feature extractor alternates a 3x3 conv with a stride-2 4x4 conv at
    widths ``base_nf`` .. ``base_nf * 8`` (the last width is used three
    times); per the original size notes the map reaching the classifier is
    3x3 (this assumes ``B.conv_block`` pads so that only the stride changes
    resolution).  The classifier is a two-layer MLP producing one score per
    sample.

    Args:
        in_nc: number of input channels.
        base_nf: channel width of the first stage.
        norm_type, act_type, mode: forwarded to ``B.conv_block``; the very
            first conv is always built with ``norm_type=None``.
    """

    def __init__(self, in_nc, base_nf, norm_type='batch', act_type='leakyrelu', mode='CNA'):
        super(Discriminator_VGG_192, self).__init__()
        # Keyword arguments shared by every block except the un-normalised stem.
        common = dict(norm_type=norm_type, act_type=act_type, mode=mode)
        # features (hxw, c -- channel counts quoted for base_nf = 64)
        # 192, 64
        layers = [
            B.conv_block(in_nc, base_nf, kernel_size=3, norm_type=None, act_type=act_type,
                         mode=mode),
            B.conv_block(base_nf, base_nf, kernel_size=4, stride=2, **common),
        ]
        # 96, 64 -> 48, 128 -> 24, 256 -> 12, 512 -> 6, 512 -> 3, 512
        for nf_in, nf_out in ((base_nf, base_nf * 2),
                              (base_nf * 2, base_nf * 4),
                              (base_nf * 4, base_nf * 8),
                              (base_nf * 8, base_nf * 8),
                              (base_nf * 8, base_nf * 8)):
            layers.append(B.conv_block(nf_in, nf_out, kernel_size=3, stride=1, **common))
            layers.append(B.conv_block(nf_out, nf_out, kernel_size=4, stride=2, **common))
        self.features = B.sequential(*layers)

        # classifier: input width follows base_nf (a hard-coded 512 only
        # matches base_nf == 64)
        self.classifier = nn.Sequential(
            nn.Linear(base_nf * 8 * 3 * 3, 100), nn.LeakyReLU(0.2, True), nn.Linear(100, 1))

    def forward(self, x):
        """Return the classifier output for a batch ``x``."""
        x = self.features(x)
        x = x.view(x.size(0), -1)  # flatten everything but the batch axis
        x = self.classifier(x)
        return x
####################
# Perceptual Network
####################

View File

@ -29,6 +29,8 @@ def norm(norm_type, nc):
layer = nn.BatchNorm2d(nc, affine=True)
elif norm_type == 'instance':
layer = nn.InstanceNorm2d(nc, affine=False)
elif norm_type == 'group':
layer = nn.GroupNorm(8, nc)
else:
raise NotImplementedError('normalization layer [{:s}] is not found'.format(norm_type))
return layer

View File

@ -2,7 +2,7 @@ import torch
import torch.nn as nn
# Define GAN loss: [vanilla | lsgan | wgan-gp]
# Define GAN loss: [vanilla | lsgan]
class GANLoss(nn.Module):
def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0):
super(GANLoss, self).__init__()
@ -14,19 +14,10 @@ class GANLoss(nn.Module):
self.loss = nn.BCEWithLogitsLoss()
elif self.gan_type == 'lsgan':
self.loss = nn.MSELoss()
elif self.gan_type == 'wgan-gp':
def wgan_loss(input, target):
# target is boolean
return -1 * input.mean() if target else input.mean()
self.loss = wgan_loss
else:
raise NotImplementedError('GAN type [{:s}] is not found'.format(self.gan_type))
def get_target_label(self, input, target_is_real):
if self.gan_type == 'wgan-gp':
return target_is_real
if target_is_real:
return torch.empty_like(input).fill_(self.real_label_val)
else:
@ -36,25 +27,3 @@ class GANLoss(nn.Module):
target_label = self.get_target_label(input, target_is_real)
loss = self.loss(input, target_label)
return loss
class GradientPenaltyLoss(nn.Module):
    """Gradient penalty: mean of ``(||d crit / d interp||_2 - 1) ** 2``.

    The gradient of ``interp_crit`` with respect to ``interp`` is taken with
    an all-ones ``grad_outputs`` tensor, flattened per sample, and its L2
    norm is pushed towards 1.

    Args:
        device: device on which the (initially empty) ``grad_outputs`` buffer
            is created.
    """

    def __init__(self, device=torch.device('cpu')):
        super(GradientPenaltyLoss, self).__init__()
        self.register_buffer('grad_outputs', torch.Tensor())
        self.grad_outputs = self.grad_outputs.to(device)

    def get_grad_outputs(self, input):
        """Return a cached all-ones tensor matching ``input``.

        The cache is rebuilt whenever shape, dtype or device differ from
        ``input``; checking the shape alone would hand ``autograd.grad`` a
        mismatched tensor after a precision or device change.
        """
        cached = self.grad_outputs
        if (cached.size() != input.size()
                or cached.dtype != input.dtype
                or cached.device != input.device):
            self.grad_outputs = torch.ones_like(input)
        return self.grad_outputs

    def forward(self, interp, interp_crit):
        """Return the scalar penalty.

        Args:
            interp: tensor the critic was evaluated on; must require grad.
            interp_crit: critic output computed from ``interp``.
        """
        grad_outputs = self.get_grad_outputs(interp_crit)
        # create_graph=True keeps the penalty itself differentiable.
        grad_interp = torch.autograd.grad(outputs=interp_crit, inputs=interp,
                                          grad_outputs=grad_outputs, create_graph=True,
                                          retain_graph=True, only_inputs=True)[0]
        # reshape (not view) so a non-contiguous gradient is also accepted.
        grad_interp = grad_interp.reshape(grad_interp.size(0), -1)
        grad_interp_norm = grad_interp.norm(2, dim=1)

        loss = ((grad_interp_norm - 1)**2).mean()
        return loss

View File

@ -1,149 +0,0 @@
'''
Copied from the PyTorch GitHub repo
Spectral Normalization from https://arxiv.org/abs/1802.05957
'''
import torch
from torch.nn.functional import normalize
from torch.nn.parameter import Parameter
class SpectralNorm(object):
    """Forward pre-hook that divides a module's weight by its spectral norm.

    ``apply`` moves the original parameter to ``<name>_orig`` and registers
    ``<name>`` and ``<name>_u`` as buffers.  On every forward call in
    training mode the hook refreshes ``u`` by power iteration and stores the
    rescaled weight under ``<name>``.
    """

    def __init__(self, name='weight', n_power_iterations=1, dim=0, eps=1e-12):
        if n_power_iterations <= 0:
            raise ValueError('Expected n_power_iterations to be positive, but '
                             'got n_power_iterations={}'.format(n_power_iterations))
        self.name = name
        self.dim = dim
        self.n_power_iterations = n_power_iterations
        self.eps = eps

    def compute_weight(self, module):
        """Return ``(weight / sigma, u)`` for the module's current weight."""
        original = getattr(module, self.name + '_orig')
        u = getattr(module, self.name + '_u')
        matrix = original
        if self.dim != 0:
            # permute dim to front
            remaining = [d for d in range(matrix.dim()) if d != self.dim]
            matrix = matrix.permute(self.dim, *remaining)
        rows = matrix.size(0)
        matrix = matrix.reshape(rows, -1)
        with torch.no_grad():
            for _ in range(self.n_power_iterations):
                # Spectral norm of weight equals to `u^T W v`, where `u` and `v`
                # are the first left and right singular vectors.
                # This power iteration produces approximations of `u` and `v`.
                v = normalize(torch.matmul(matrix.t(), u), dim=0, eps=self.eps)
                u = normalize(torch.matmul(matrix, v), dim=0, eps=self.eps)

        # Computed outside no_grad so the division stays differentiable with
        # respect to the original weight.
        sigma = torch.dot(u, torch.matmul(matrix, v))
        return original / sigma, u

    def remove(self, module):
        """Replace the buffers with a plain parameter holding the current weight."""
        current = getattr(module, self.name)
        for attribute in (self.name, self.name + '_u', self.name + '_orig'):
            delattr(module, attribute)
        module.register_parameter(self.name, torch.nn.Parameter(current))

    def __call__(self, module, inputs):
        if not module.training:
            # Eval mode: keep the last computed weight, detached, with the
            # same requires_grad flag as the original parameter.
            needs_grad = getattr(module, self.name + '_orig').requires_grad
            getattr(module, self.name).detach_().requires_grad_(needs_grad)
            return
        normalized, u = self.compute_weight(module)
        setattr(module, self.name, normalized)
        setattr(module, self.name + '_u', u)

    @staticmethod
    def apply(module, name, n_power_iterations, dim, eps):
        """Install the hook on ``module`` and return the hook object."""
        hook = SpectralNorm(name, n_power_iterations, dim, eps)
        weight = module._parameters[name]
        # One entry of u per slice of the weight along `dim`.
        u = normalize(weight.new_empty(weight.size(dim)).normal_(0, 1), dim=0, eps=hook.eps)
        delattr(module, hook.name)
        module.register_parameter(hook.name + "_orig", weight)
        # We still need to assign weight back as hook.name because all sorts of
        # things may assume that it exists, e.g., when initializing weights.
        # However, we can't directly assign as it could be an nn.Parameter and
        # gets added as a parameter. Instead, we register weight.data as a
        # buffer, which will cause weight to be included in the state dict
        # and also supports nn.init due to shared storage.
        module.register_buffer(hook.name, weight.data)
        module.register_buffer(hook.name + "_u", u)

        module.register_forward_pre_hook(hook)
        return hook
def spectral_norm(module, name='weight', n_power_iterations=1, eps=1e-12, dim=None):
    r"""Apply spectral normalization to a parameter of ``module``.

    .. math::
         \mathbf{W} &= \dfrac{\mathbf{W}}{\sigma(\mathbf{W})} \\
         \sigma(\mathbf{W}) &= \max_{\mathbf{h}: \mathbf{h} \ne 0} \dfrac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2}

    Spectral normalization stabilizes the training of discriminators
    (critics) in Generative Adversarial Networks (GANs) by rescaling the
    weight tensor with the spectral norm :math:`\sigma` of the weight matrix,
    estimated with the power iteration method. A weight tensor with more than
    two dimensions is reshaped to 2D for the power iteration. The rescaling
    is done by a hook that runs before every :meth:`~Module.forward` call.

    See `Spectral Normalization for Generative Adversarial Networks`_ .

    .. _`Spectral Normalization for Generative Adversarial Networks`: https://arxiv.org/abs/1802.05957

    Args:
        module (nn.Module): containing module
        name (str, optional): name of weight parameter
        n_power_iterations (int, optional): number of power iterations used
            to estimate the spectral norm
        eps (float, optional): epsilon for numerical stability when
            normalizing vectors
        dim (int, optional): dimension corresponding to number of outputs;
            defaults to 0, except for ConvTranspose1d/2d/3d modules, where
            it is 1

    Returns:
        The original module with the spectral norm hook installed

    Example::

        >>> m = spectral_norm(nn.Linear(20, 40))
        >>> m
        Linear(in_features=20, out_features=40, bias=True)
        >>> m.weight_u.size()
        torch.Size([40])

    """
    if dim is None:
        transposed_convs = (torch.nn.ConvTranspose1d,
                            torch.nn.ConvTranspose2d,
                            torch.nn.ConvTranspose3d)
        # Transposed convolutions keep their output channels on axis 1.
        dim = 1 if isinstance(module, transposed_convs) else 0
    SpectralNorm.apply(module, name, n_power_iterations, dim, eps)
    return module
def remove_spectral_norm(module, name='weight'):
    r"""Undo the spectral normalization reparameterization on ``module``.

    Args:
        module (nn.Module): containing module
        name (str, optional): name of weight parameter

    Returns:
        ``module`` itself, once the matching hook has been removed.

    Raises:
        ValueError: if no spectral norm hook for ``name`` is registered.

    Example:
        >>> m = spectral_norm(nn.Linear(40, 10))
        >>> remove_spectral_norm(m)
    """
    pre_hooks = module._forward_pre_hooks
    for key, candidate in pre_hooks.items():
        if not isinstance(candidate, SpectralNorm):
            continue
        if candidate.name != name:
            continue
        candidate.remove(module)
        # Deleting during iteration is safe here: we return immediately.
        del pre_hooks[key]
        return module

    raise ValueError("spectral_norm of '{}' not found in {}".format(name, module))

View File

@ -18,7 +18,7 @@ class CharbonnierLoss(nn.Module):
return loss
# Define GAN loss: [vanilla | lsgan | wgan-gp]
# Define GAN loss: [vanilla | lsgan]
class GANLoss(nn.Module):
def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0):
super(GANLoss, self).__init__()
@ -30,19 +30,10 @@ class GANLoss(nn.Module):
self.loss = nn.BCEWithLogitsLoss()
elif self.gan_type == 'lsgan':
self.loss = nn.MSELoss()
elif self.gan_type == 'wgan-gp':
def wgan_loss(input, target):
# target is boolean
return -1 * input.mean() if target else input.mean()
self.loss = wgan_loss
else:
raise NotImplementedError('GAN type [{:s}] is not found'.format(self.gan_type))
def get_target_label(self, input, target_is_real):
if self.gan_type == 'wgan-gp':
return target_is_real
if target_is_real:
return torch.empty_like(input).fill_(self.real_label_val)
else:
@ -57,29 +48,6 @@ class GANLoss(nn.Module):
return loss
class GradientPenaltyLoss(nn.Module):
    """Gradient penalty: mean of ``(||d crit / d interp||_2 - 1) ** 2``.

    The gradient of ``interp_crit`` with respect to ``interp`` is taken with
    an all-ones ``grad_outputs`` tensor, flattened per sample, and its L2
    norm is pushed towards 1.

    Args:
        device: device on which the (initially empty) ``grad_outputs`` buffer
            is created.
    """

    def __init__(self, device=torch.device('cpu')):
        super(GradientPenaltyLoss, self).__init__()
        self.register_buffer('grad_outputs', torch.Tensor())
        self.grad_outputs = self.grad_outputs.to(device)

    def get_grad_outputs(self, input):
        """Return a cached all-ones tensor matching ``input``.

        The cache is rebuilt whenever shape, dtype or device differ from
        ``input``; checking the shape alone would hand ``autograd.grad`` a
        mismatched tensor after a precision or device change.
        """
        cached = self.grad_outputs
        if (cached.size() != input.size()
                or cached.dtype != input.dtype
                or cached.device != input.device):
            self.grad_outputs = torch.ones_like(input)
        return self.grad_outputs

    def forward(self, interp, interp_crit):
        """Return the scalar penalty.

        Args:
            interp: tensor the critic was evaluated on; must require grad.
            interp_crit: critic output computed from ``interp``.
        """
        grad_outputs = self.get_grad_outputs(interp_crit)
        # create_graph=True keeps the penalty itself differentiable.
        grad_interp = torch.autograd.grad(outputs=interp_crit, inputs=interp,
                                          grad_outputs=grad_outputs, create_graph=True,
                                          retain_graph=True, only_inputs=True)[0]
        # reshape (not view) so a non-contiguous gradient is also accepted.
        grad_interp = grad_interp.reshape(grad_interp.size(0), -1)
        grad_interp_norm = grad_interp.norm(2, dim=1)

        loss = ((grad_interp_norm - 1)**2).mean()
        return loss
# Frequency Domain Perceptual Loss, from https://github.com/sdv4/FDPL
# Utilizes pre-computed perceptual_weights. To generate these from your dataset, see data_scripts/compute_fdpl_perceptual_weights.py
# In practice, per the paper, these precomputed weights can generally be used across broad image classes (e.g. all photographs).

View File

@ -215,7 +215,7 @@ def main():
logger.info(message)
#### validation
if opt['datasets'].get('val', None) and current_step % opt['train']['val_freq'] == 0:
if opt['model'] in ['sr', 'srgan', 'corruptgan'] and rank <= 0: # image restoration validation
if opt['model'] in ['sr', 'srgan', 'corruptgan', 'spsr'] and rank <= 0: # image restoration validation
model.force_restore_swapout()
val_batch_sz = 1 if 'batch_size' not in opt['datasets']['val'].keys() else opt['datasets']['val']['batch_size']
# does not support multi-GPU validation