Substantial SPSR mods & fixes
- Added in gradient accumulation via mega-batch-factor - Added AMP - Added missing train hooks - Added debug image outputs - Cleaned up including removing GradientPenaltyLoss, custom SpectralNorm - Removed all the custom discriminators
This commit is contained in:
parent
f894ba8f98
commit
c8da78966b
|
@ -5,10 +5,13 @@ from collections import OrderedDict
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.optim import lr_scheduler
|
||||
from apex import amp
|
||||
|
||||
import models.SPSR_networks as networks
|
||||
from .base_model import BaseModel
|
||||
from models.SPSR_modules.loss import GANLoss, GradientPenaltyLoss
|
||||
from models.SPSR_modules.loss import GANLoss
|
||||
import torchvision.utils as utils
|
||||
|
||||
logger = logging.getLogger('base')
|
||||
|
||||
import torch.nn.functional as F
|
||||
|
@ -95,10 +98,13 @@ class SPSRModel(BaseModel):
|
|||
self.netG.train()
|
||||
self.netD.train()
|
||||
self.netD_grad.train()
|
||||
self.mega_batch_factor = 1
|
||||
self.load() # load G and D if needed
|
||||
|
||||
# define losses, optimizer and scheduler
|
||||
if self.is_train:
|
||||
self.mega_batch_factor = train_opt['mega_batch_factor']
|
||||
|
||||
# G pixel loss
|
||||
if train_opt['pixel_weight'] > 0:
|
||||
l_pix_type = train_opt['pixel_criterion']
|
||||
|
@ -139,12 +145,6 @@ class SPSRModel(BaseModel):
|
|||
self.Branch_pretrain = train_opt['Branch_pretrain'] if train_opt['Branch_pretrain'] else 0
|
||||
self.Branch_init_iters = train_opt['Branch_init_iters'] if train_opt['Branch_init_iters'] else 1
|
||||
|
||||
if train_opt['gan_type'] == 'wgan-gp':
|
||||
self.random_pt = torch.Tensor(1, 1, 1, 1).to(self.device)
|
||||
# gradient penalty loss
|
||||
self.cri_gp = GradientPenaltyLoss(device=self.device).to(self.device)
|
||||
self.l_gp_w = train_opt['gp_weigth']
|
||||
|
||||
# gradient_pixel_loss
|
||||
if train_opt['gradient_pixel_weight'] > 0:
|
||||
self.cri_pix_grad = nn.MSELoss().to(self.device)
|
||||
|
@ -202,6 +202,12 @@ class SPSRModel(BaseModel):
|
|||
|
||||
self.optimizers.append(self.optimizer_D_grad)
|
||||
|
||||
# AMP
|
||||
[self.netG, self.netD, self.netD_grad], [self.optimizer_G, self.optimizer_D, self.optimizer_D_grad] = \
|
||||
amp.initialize([self.netG, self.netD, self.netD_grad],
|
||||
[self.optimizer_G, self.optimizer_D, self.optimizer_D_grad],
|
||||
opt_level=self.amp_level, num_losses=3)
|
||||
|
||||
# schedulers
|
||||
if train_opt['lr_scheme'] == 'MultiStepLR':
|
||||
for optimizer in self.optimizers:
|
||||
|
@ -216,12 +222,12 @@ class SPSRModel(BaseModel):
|
|||
|
||||
def feed_data(self, data, need_HR=True):
|
||||
# LR
|
||||
self.var_L = data['LQ'].to(self.device)
|
||||
self.var_L = [t.to(self.device) for t in torch.chunk(data['LQ'], chunks=self.mega_batch_factor, dim=0)]
|
||||
|
||||
if need_HR: # train or val
|
||||
self.var_H = data['GT'].to(self.device)
|
||||
self.var_H = [t.to(self.device) for t in torch.chunk(data['GT'], chunks=self.mega_batch_factor, dim=0)]
|
||||
input_ref = data['ref'] if 'ref' in data else data['GT']
|
||||
self.var_ref = input_ref.to(self.device)
|
||||
self.var_ref = [t.to(self.device) for t in torch.chunk(input_ref.to(self.device), chunks=self.mega_batch_factor, dim=0)]
|
||||
|
||||
|
||||
|
||||
|
@ -247,118 +253,118 @@ class SPSRModel(BaseModel):
|
|||
|
||||
self.optimizer_G.zero_grad()
|
||||
|
||||
self.fake_H_branch, self.fake_H, self.grad_LR = self.netG(self.var_L)
|
||||
self.fake_H_branch = []
|
||||
self.fake_H = []
|
||||
self.grad_LR = []
|
||||
for var_L, var_H, var_ref in zip(self.var_L, self.var_H, self.var_ref):
|
||||
fake_H_branch, fake_H, grad_LR = self.netG(var_L)
|
||||
self.fake_H_branch.append(fake_H_branch.detach())
|
||||
self.fake_H.append(fake_H.detach())
|
||||
self.grad_LR.append(grad_LR.detach())
|
||||
|
||||
|
||||
self.fake_H_grad = self.get_grad(self.fake_H)
|
||||
self.var_H_grad = self.get_grad(self.var_H)
|
||||
self.var_ref_grad = self.get_grad(self.var_ref)
|
||||
self.var_H_grad_nopadding = self.get_grad_nopadding(self.var_H)
|
||||
|
||||
fake_H_grad = self.get_grad(fake_H)
|
||||
var_H_grad = self.get_grad(var_H)
|
||||
var_ref_grad = self.get_grad(var_ref)
|
||||
var_H_grad_nopadding = self.get_grad_nopadding(var_H)
|
||||
|
||||
l_g_total = 0
|
||||
if step % self.D_update_ratio == 0 and step > self.D_init_iters:
|
||||
if self.cri_pix: # pixel loss
|
||||
l_g_pix = self.l_pix_w * self.cri_pix(fake_H, var_H)
|
||||
l_g_total += l_g_pix
|
||||
if self.cri_fea: # feature loss
|
||||
real_fea = self.netF(var_H).detach()
|
||||
fake_fea = self.netF(fake_H)
|
||||
l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea)
|
||||
l_g_total += l_g_fea
|
||||
|
||||
if self.cri_pix_grad: #gradient pixel loss
|
||||
l_g_pix_grad = self.l_pix_grad_w * self.cri_pix_grad(fake_H_grad, var_H_grad)
|
||||
l_g_total += l_g_pix_grad
|
||||
|
||||
if self.cri_pix_branch: #branch pixel loss
|
||||
l_g_pix_grad_branch = self.l_pix_branch_w * self.cri_pix_branch(fake_H_branch, var_H_grad_nopadding)
|
||||
l_g_total += l_g_pix_grad_branch
|
||||
|
||||
if self.l_gan_w > 0:
|
||||
# G gan + cls loss
|
||||
pred_g_fake = self.netD(fake_H)
|
||||
pred_d_real = self.netD(var_ref).detach()
|
||||
|
||||
l_g_gan = self.l_gan_w * (self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False) +
|
||||
self.cri_gan(pred_g_fake - torch.mean(pred_d_real), True)) / 2
|
||||
l_g_total += l_g_gan
|
||||
|
||||
if self.cri_grad_gan:
|
||||
# grad G gan + cls loss
|
||||
pred_g_fake_grad = self.netD_grad(fake_H_grad)
|
||||
pred_d_real_grad = self.netD_grad(var_ref_grad).detach()
|
||||
|
||||
l_g_gan_grad = self.l_gan_grad_w * (self.cri_grad_gan(pred_d_real_grad - torch.mean(pred_g_fake_grad), False) +
|
||||
self.cri_grad_gan(pred_g_fake_grad - torch.mean(pred_d_real_grad), True)) /2
|
||||
l_g_total += l_g_gan_grad
|
||||
|
||||
l_g_total /= self.mega_batch_factor
|
||||
with amp.scale_loss(l_g_total, self.optimizer_G, loss_id=0) as l_g_total_scaled:
|
||||
l_g_total_scaled.backward()
|
||||
|
||||
l_g_total = 0
|
||||
if step % self.D_update_ratio == 0 and step > self.D_init_iters:
|
||||
if self.cri_pix: # pixel loss
|
||||
l_g_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.var_H)
|
||||
l_g_total += l_g_pix
|
||||
if self.cri_fea: # feature loss
|
||||
real_fea = self.netF(self.var_H).detach()
|
||||
fake_fea = self.netF(self.fake_H)
|
||||
l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea)
|
||||
l_g_total += l_g_fea
|
||||
|
||||
if self.cri_pix_grad: #gradient pixel loss
|
||||
l_g_pix_grad = self.l_pix_grad_w * self.cri_pix_grad(self.fake_H_grad, self.var_H_grad)
|
||||
l_g_total += l_g_pix_grad
|
||||
|
||||
|
||||
if self.cri_pix_branch: #branch pixel loss
|
||||
l_g_pix_grad_branch = self.l_pix_branch_w * self.cri_pix_branch(self.fake_H_branch, self.var_H_grad_nopadding)
|
||||
l_g_total += l_g_pix_grad_branch
|
||||
|
||||
|
||||
# G gan + cls loss
|
||||
pred_g_fake = self.netD(self.fake_H)
|
||||
pred_d_real = self.netD(self.var_ref).detach()
|
||||
|
||||
l_g_gan = self.l_gan_w * (self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False) +
|
||||
self.cri_gan(pred_g_fake - torch.mean(pred_d_real), True)) / 2
|
||||
l_g_total += l_g_gan
|
||||
|
||||
# grad G gan + cls loss
|
||||
|
||||
pred_g_fake_grad = self.netD_grad(self.fake_H_grad)
|
||||
pred_d_real_grad = self.netD_grad(self.var_ref_grad).detach()
|
||||
|
||||
l_g_gan_grad = self.l_gan_grad_w * (self.cri_grad_gan(pred_d_real_grad - torch.mean(pred_g_fake_grad), False) +
|
||||
self.cri_grad_gan(pred_g_fake_grad - torch.mean(pred_d_real_grad), True)) /2
|
||||
l_g_total += l_g_gan_grad
|
||||
|
||||
|
||||
l_g_total.backward()
|
||||
self.optimizer_G.step()
|
||||
|
||||
|
||||
# D
|
||||
for p in self.netD.parameters():
|
||||
p.requires_grad = True
|
||||
if self.l_gan_w > 0:
|
||||
# D
|
||||
for p in self.netD.parameters():
|
||||
p.requires_grad = True
|
||||
|
||||
self.optimizer_D.zero_grad()
|
||||
l_d_total = 0
|
||||
pred_d_real = self.netD(self.var_ref)
|
||||
pred_d_fake = self.netD(self.fake_H.detach()) # detach to avoid BP to G
|
||||
l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True)
|
||||
l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real), False)
|
||||
self.optimizer_D.zero_grad()
|
||||
for var_ref, fake_H in zip(self.var_ref, self.fake_H):
|
||||
pred_d_real = self.netD(var_ref)
|
||||
pred_d_fake = self.netD(fake_H) # detach to avoid BP to G
|
||||
l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True)
|
||||
l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real), False)
|
||||
|
||||
l_d_total = (l_d_real + l_d_fake) / 2
|
||||
l_d_total = (l_d_real + l_d_fake) / 2
|
||||
|
||||
if self.opt['train']['gan_type'] == 'wgan-gp':
|
||||
batch_size = self.var_ref.size(0)
|
||||
if self.random_pt.size(0) != batch_size:
|
||||
self.random_pt.resize_(batch_size, 1, 1, 1)
|
||||
self.random_pt.uniform_() # Draw random interpolation points
|
||||
interp = self.random_pt * self.fake_H.detach() + (1 - self.random_pt) * self.var_ref
|
||||
interp.requires_grad = True
|
||||
interp_crit, _ = self.netD(interp)
|
||||
l_d_gp = self.l_gp_w * self.cri_gp(interp, interp_crit)
|
||||
l_d_total += l_d_gp
|
||||
l_d_total /= self.mega_batch_factor
|
||||
with amp.scale_loss(l_d_total, self.optimizer_D, loss_id=1) as l_d_total_scaled:
|
||||
l_d_total_scaled.backward()
|
||||
|
||||
l_d_total.backward()
|
||||
self.optimizer_D.step()
|
||||
|
||||
self.optimizer_D.step()
|
||||
if self.cri_grad_gan:
|
||||
for p in self.netD_grad.parameters():
|
||||
p.requires_grad = True
|
||||
|
||||
|
||||
for p in self.netD_grad.parameters():
|
||||
p.requires_grad = True
|
||||
self.optimizer_D_grad.zero_grad()
|
||||
for var_ref, fake_H in zip(self.var_ref, self.fake_H):
|
||||
fake_H_grad = self.get_grad(fake_H)
|
||||
var_ref_grad = self.get_grad(var_ref)
|
||||
|
||||
self.optimizer_D_grad.zero_grad()
|
||||
l_d_total_grad = 0
|
||||
pred_d_real_grad = self.netD_grad(var_ref_grad)
|
||||
pred_d_fake_grad = self.netD_grad(fake_H_grad.detach()) # detach to avoid BP to G
|
||||
|
||||
|
||||
pred_d_real_grad = self.netD_grad(self.var_ref_grad)
|
||||
pred_d_fake_grad = self.netD_grad(self.fake_H_grad.detach()) # detach to avoid BP to G
|
||||
|
||||
l_d_real_grad = self.cri_grad_gan(pred_d_real_grad - torch.mean(pred_d_fake_grad), True)
|
||||
l_d_fake_grad = self.cri_grad_gan(pred_d_fake_grad - torch.mean(pred_d_real_grad), False)
|
||||
l_d_real_grad = self.cri_grad_gan(pred_d_real_grad - torch.mean(pred_d_fake_grad), True)
|
||||
l_d_fake_grad = self.cri_grad_gan(pred_d_fake_grad - torch.mean(pred_d_real_grad), False)
|
||||
|
||||
l_d_total_grad = (l_d_real_grad + l_d_fake_grad) / 2
|
||||
l_d_total_grad = (l_d_real_grad + l_d_fake_grad) / 2
|
||||
l_d_total_grad /= self.mega_batch_factor
|
||||
|
||||
with amp.scale_loss(l_d_total_grad, self.optimizer_D_grad, loss_id=2) as l_d_total_grad_scaled:
|
||||
l_d_total_grad_scaled.backward()
|
||||
|
||||
l_d_total_grad.backward()
|
||||
|
||||
self.optimizer_D_grad.step()
|
||||
self.optimizer_D_grad.step()
|
||||
|
||||
# Log sample images from first microbatch.
|
||||
if step % 50 == 0:
|
||||
import torchvision.utils as utils
|
||||
sample_save_path = os.path.join(self.opt['path']['models'], "..", "temp")
|
||||
os.makedirs(os.path.join(sample_save_path, "hr"), exist_ok=True)
|
||||
os.makedirs(os.path.join(sample_save_path, "lr"), exist_ok=True)
|
||||
os.makedirs(os.path.join(sample_save_path, "gen"), exist_ok=True)
|
||||
# fed_LQ is not chunked.
|
||||
utils.save_image(self.var_H.cpu(), os.path.join(sample_save_path, "hr", "%05i.png" % (step,)))
|
||||
utils.save_image(self.var_L.cpu(), os.path.join(sample_save_path, "lr", "%05i.png" % (step,)))
|
||||
utils.save_image(self.fake_H.cpu(), os.path.join(sample_save_path, "gen", "%05i.png" % (step,)))
|
||||
utils.save_image(self.var_H[0].cpu(), os.path.join(sample_save_path, "hr", "%05i.png" % (step,)))
|
||||
utils.save_image(self.var_L[0].cpu(), os.path.join(sample_save_path, "lr", "%05i.png" % (step,)))
|
||||
utils.save_image(self.fake_H[0].cpu(), os.path.join(sample_save_path, "gen", "%05i.png" % (step,)))
|
||||
|
||||
|
||||
# set log
|
||||
|
@ -368,33 +374,42 @@ class SPSRModel(BaseModel):
|
|||
self.log_dict['l_g_pix'] = l_g_pix.item()
|
||||
if self.cri_fea:
|
||||
self.log_dict['l_g_fea'] = l_g_fea.item()
|
||||
self.log_dict['l_g_gan'] = l_g_gan.item()
|
||||
if self.l_gan_w > 0:
|
||||
self.log_dict['l_g_gan'] = l_g_gan.item()
|
||||
|
||||
if self.cri_pix_branch: #branch pixel loss
|
||||
self.log_dict['l_g_pix_grad_branch'] = l_g_pix_grad_branch.item()
|
||||
|
||||
# D
|
||||
self.log_dict['l_d_real'] = l_d_real.item()
|
||||
self.log_dict['l_d_fake'] = l_d_fake.item()
|
||||
|
||||
# D_grad
|
||||
self.log_dict['l_d_real_grad'] = l_d_real_grad.item()
|
||||
self.log_dict['l_d_fake_grad'] = l_d_fake_grad.item()
|
||||
if self.l_gan_w > 0:
|
||||
# D
|
||||
self.log_dict['l_d_real'] = l_d_real.item()
|
||||
self.log_dict['l_d_fake'] = l_d_fake.item()
|
||||
|
||||
if self.opt['train']['gan_type'] == 'wgan-gp':
|
||||
self.log_dict['l_d_gp'] = l_d_gp.item()
|
||||
# D outputs
|
||||
self.log_dict['D_real'] = torch.mean(pred_d_real.detach())
|
||||
self.log_dict['D_fake'] = torch.mean(pred_d_fake.detach())
|
||||
# D_grad
|
||||
self.log_dict['l_d_real_grad'] = l_d_real_grad.item()
|
||||
self.log_dict['l_d_fake_grad'] = l_d_fake_grad.item()
|
||||
|
||||
# D_grad outputs
|
||||
self.log_dict['D_real_grad'] = torch.mean(pred_d_real_grad.detach())
|
||||
self.log_dict['D_fake_grad'] = torch.mean(pred_d_fake_grad.detach())
|
||||
if self.opt['train']['gan_type'] == 'wgan-gp':
|
||||
self.log_dict['l_d_gp'] = l_d_gp.item()
|
||||
# D outputs
|
||||
self.log_dict['D_real'] = torch.mean(pred_d_real.detach())
|
||||
self.log_dict['D_fake'] = torch.mean(pred_d_fake.detach())
|
||||
|
||||
# D_grad outputs
|
||||
self.log_dict['D_real_grad'] = torch.mean(pred_d_real_grad.detach())
|
||||
self.log_dict['D_fake_grad'] = torch.mean(pred_d_fake_grad.detach())
|
||||
|
||||
def test(self):
|
||||
self.netG.eval()
|
||||
with torch.no_grad():
|
||||
self.fake_H_branch, self.fake_H, self.grad_LR = self.netG(self.var_L)
|
||||
self.fake_H_branch = []
|
||||
self.fake_H = []
|
||||
self.grad_LR = []
|
||||
for var_L in self.var_L:
|
||||
fake_H_branch, fake_H, grad_LR = self.netG(var_L)
|
||||
self.fake_H_branch.append(fake_H_branch)
|
||||
self.fake_H.append(fake_H)
|
||||
self.grad_LR.append(grad_LR)
|
||||
|
||||
self.netG.train()
|
||||
|
||||
|
@ -403,13 +418,13 @@ class SPSRModel(BaseModel):
|
|||
|
||||
def get_current_visuals(self, need_HR=True):
|
||||
out_dict = OrderedDict()
|
||||
out_dict['LR'] = self.var_L.detach()[0].float().cpu()
|
||||
out_dict['LR'] = self.var_L[0].float().cpu()
|
||||
|
||||
out_dict['SR'] = self.fake_H.detach()[0].float().cpu()
|
||||
out_dict['SR_branch'] = self.fake_H_branch.detach()[0].float().cpu()
|
||||
out_dict['LR_grad'] = self.grad_LR.detach()[0].float().cpu()
|
||||
out_dict['rlt'] = self.fake_H[0].float().cpu()
|
||||
out_dict['SR_branch'] = self.fake_H_branch[0].float().cpu()
|
||||
out_dict['LR_grad'] = self.grad_LR[0].float().cpu()
|
||||
if need_HR:
|
||||
out_dict['HR'] = self.var_H.detach()[0].float().cpu()
|
||||
out_dict['GT'] = self.var_H[0].float().cpu()
|
||||
return out_dict
|
||||
|
||||
def print_network(self):
|
||||
|
@ -456,7 +471,31 @@ class SPSRModel(BaseModel):
|
|||
logger.info('Loading pretrained model for D [{:s}] ...'.format(load_path_D))
|
||||
self.load_network(load_path_D, self.netD)
|
||||
|
||||
def compute_fea_loss(self, real, fake):
|
||||
if self.cri_fea is None:
|
||||
return 0
|
||||
with torch.no_grad():
|
||||
real = real.unsqueeze(dim=0).to(self.device)
|
||||
fake = fake.unsqueeze(dim=0).to(self.device)
|
||||
real_fea = self.netF(real).detach()
|
||||
fake_fea = self.netF(fake)
|
||||
return self.cri_fea(fake_fea, real_fea).item()
|
||||
|
||||
def force_restore_swapout(self):
|
||||
pass
|
||||
|
||||
def save(self, iter_step):
|
||||
self.save_network(self.netG, 'G', iter_step)
|
||||
self.save_network(self.netD, 'D', iter_step)
|
||||
self.save_network(self.netD_grad, 'D_grad', iter_step)
|
||||
|
||||
# override of load_network that allows loading partial params (like RRDB_PSNR_x4)
|
||||
def load_network(self, load_path, network, strict=True):
|
||||
if isinstance(network, nn.DataParallel):
|
||||
network = network.module
|
||||
pretrained_dict = torch.load(load_path)
|
||||
model_dict = network.state_dict()
|
||||
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
|
||||
|
||||
model_dict.update(pretrained_dict)
|
||||
network.load_state_dict(model_dict)
|
|
@ -4,8 +4,6 @@ import torch.nn as nn
|
|||
import torch.nn.functional as F
|
||||
import torchvision
|
||||
from . import block as B
|
||||
from . import spectral_norm as SN
|
||||
|
||||
|
||||
|
||||
class Get_gradient_nopadding(nn.Module):
|
||||
|
@ -248,292 +246,6 @@ class Discriminator_VGG_128(nn.Module):
|
|||
return x
|
||||
|
||||
|
||||
# VGG style Discriminator with input size 96*96
|
||||
class Discriminator_VGG_96(nn.Module):
|
||||
def __init__(self, in_nc, base_nf, norm_type='batch', act_type='leakyrelu', mode='CNA'):
|
||||
super(Discriminator_VGG_96, self).__init__()
|
||||
# features
|
||||
# hxw, c
|
||||
# 96, 3
|
||||
conv0 = B.conv_block(in_nc, base_nf, kernel_size=3, norm_type=None, act_type=act_type, \
|
||||
mode=mode)
|
||||
conv1 = B.conv_block(base_nf, base_nf, kernel_size=4, stride=2, norm_type=norm_type, \
|
||||
act_type=act_type, mode=mode)
|
||||
# 48, 64
|
||||
conv2 = B.conv_block(base_nf, base_nf*2, kernel_size=3, stride=1, norm_type=norm_type, \
|
||||
act_type=act_type, mode=mode)
|
||||
conv3 = B.conv_block(base_nf*2, base_nf*2, kernel_size=4, stride=2, norm_type=norm_type, \
|
||||
act_type=act_type, mode=mode)
|
||||
# 24, 128
|
||||
conv4 = B.conv_block(base_nf*2, base_nf*4, kernel_size=3, stride=1, norm_type=norm_type, \
|
||||
act_type=act_type, mode=mode)
|
||||
conv5 = B.conv_block(base_nf*4, base_nf*4, kernel_size=4, stride=2, norm_type=norm_type, \
|
||||
act_type=act_type, mode=mode)
|
||||
# 12, 256
|
||||
conv6 = B.conv_block(base_nf*4, base_nf*8, kernel_size=3, stride=1, norm_type=norm_type, \
|
||||
act_type=act_type, mode=mode)
|
||||
conv7 = B.conv_block(base_nf*8, base_nf*8, kernel_size=4, stride=2, norm_type=norm_type, \
|
||||
act_type=act_type, mode=mode)
|
||||
# 6, 512
|
||||
self.features = B.sequential(conv0, conv1, conv2, conv3, conv4, conv5, conv6, conv7)
|
||||
|
||||
# classifier
|
||||
self.classifier = nn.Sequential(
|
||||
nn.Linear(512 * 6 * 6, 100), nn.LeakyReLU(0.2, True), nn.Linear(100, 1))
|
||||
|
||||
def forward(self, x):
|
||||
x = self.features(x)
|
||||
x = x.view(x.size(0), -1)
|
||||
x = self.classifier(x)
|
||||
return x
|
||||
|
||||
|
||||
# VGG style Discriminator with input size 64*64
|
||||
class Discriminator_VGG_64(nn.Module):
|
||||
def __init__(self, in_nc, base_nf, norm_type='batch', act_type='leakyrelu', mode='CNA'):
|
||||
super(Discriminator_VGG_64, self).__init__()
|
||||
# features
|
||||
# hxw, c
|
||||
# 64, 3
|
||||
conv0 = B.conv_block(in_nc, base_nf, kernel_size=3, norm_type=None, act_type=act_type, \
|
||||
mode=mode)
|
||||
conv1 = B.conv_block(base_nf, base_nf, kernel_size=4, stride=2, norm_type=norm_type, \
|
||||
act_type=act_type, mode=mode)
|
||||
# 32, 64
|
||||
conv2 = B.conv_block(base_nf, base_nf*2, kernel_size=3, stride=1, norm_type=norm_type, \
|
||||
act_type=act_type, mode=mode)
|
||||
conv3 = B.conv_block(base_nf*2, base_nf*2, kernel_size=4, stride=2, norm_type=norm_type, \
|
||||
act_type=act_type, mode=mode)
|
||||
# 16, 128
|
||||
conv4 = B.conv_block(base_nf*2, base_nf*4, kernel_size=3, stride=1, norm_type=norm_type, \
|
||||
act_type=act_type, mode=mode)
|
||||
conv5 = B.conv_block(base_nf*4, base_nf*4, kernel_size=4, stride=2, norm_type=norm_type, \
|
||||
act_type=act_type, mode=mode)
|
||||
# 8, 256
|
||||
conv6 = B.conv_block(base_nf*4, base_nf*8, kernel_size=3, stride=1, norm_type=norm_type, \
|
||||
act_type=act_type, mode=mode)
|
||||
conv7 = B.conv_block(base_nf*8, base_nf*8, kernel_size=4, stride=2, norm_type=norm_type, \
|
||||
act_type=act_type, mode=mode)
|
||||
# 4, 512
|
||||
self.features = B.sequential(conv0, conv1, conv2, conv3, conv4, conv5, conv6, conv7)
|
||||
|
||||
# classifier
|
||||
self.classifier = nn.Sequential(
|
||||
nn.Linear(512 * 4 * 4, 100), nn.LeakyReLU(0.2, True), nn.Linear(100, 1))
|
||||
|
||||
def forward(self, x):
|
||||
x = self.features(x)
|
||||
x = x.view(x.size(0), -1)
|
||||
x = self.classifier(x)
|
||||
return x
|
||||
|
||||
|
||||
# VGG style Discriminator with input size 32*32
|
||||
class Discriminator_VGG_32(nn.Module):
|
||||
def __init__(self, in_nc, base_nf, norm_type='batch', act_type='leakyrelu', mode='CNA'):
|
||||
super(Discriminator_VGG_32, self).__init__()
|
||||
# features
|
||||
# hxw, c
|
||||
# 32, 3
|
||||
conv0 = B.conv_block(in_nc, base_nf, kernel_size=3, norm_type=None, act_type=act_type, \
|
||||
mode=mode)
|
||||
conv1 = B.conv_block(base_nf, base_nf, kernel_size=4, stride=2, norm_type=norm_type, \
|
||||
act_type=act_type, mode=mode)
|
||||
# 16, 64
|
||||
conv2 = B.conv_block(base_nf, base_nf*2, kernel_size=3, stride=1, norm_type=norm_type, \
|
||||
act_type=act_type, mode=mode)
|
||||
conv3 = B.conv_block(base_nf*2, base_nf*2, kernel_size=4, stride=2, norm_type=norm_type, \
|
||||
act_type=act_type, mode=mode)
|
||||
# 8, 128
|
||||
conv4 = B.conv_block(base_nf*2, base_nf*4, kernel_size=3, stride=1, norm_type=norm_type, \
|
||||
act_type=act_type, mode=mode)
|
||||
conv5 = B.conv_block(base_nf*4, base_nf*4, kernel_size=4, stride=2, norm_type=norm_type, \
|
||||
act_type=act_type, mode=mode)
|
||||
# 4, 256
|
||||
self.features = B.sequential(conv0, conv1, conv2, conv3, conv4, conv5)
|
||||
|
||||
# classifier
|
||||
self.classifier = nn.Sequential(
|
||||
nn.Linear(256 * 4 * 4, 100), nn.LeakyReLU(0.2, True), nn.Linear(100, 1))
|
||||
|
||||
def forward(self, x):
|
||||
x = self.features(x)
|
||||
x = x.view(x.size(0), -1)
|
||||
x = self.classifier(x)
|
||||
return x
|
||||
|
||||
|
||||
# VGG style Discriminator with input size 16*16
|
||||
class Discriminator_VGG_16(nn.Module):
|
||||
def __init__(self, in_nc, base_nf, norm_type='batch', act_type='leakyrelu', mode='CNA'):
|
||||
super(Discriminator_VGG_16, self).__init__()
|
||||
# features
|
||||
# hxw, c
|
||||
# 16, 3
|
||||
conv0 = B.conv_block(in_nc, base_nf, kernel_size=3, norm_type=None, act_type=act_type, \
|
||||
mode=mode)
|
||||
conv1 = B.conv_block(base_nf, base_nf, kernel_size=4, stride=2, norm_type=norm_type, \
|
||||
act_type=act_type, mode=mode)
|
||||
# 8, 64
|
||||
conv2 = B.conv_block(base_nf, base_nf*2, kernel_size=3, stride=1, norm_type=norm_type, \
|
||||
act_type=act_type, mode=mode)
|
||||
conv3 = B.conv_block(base_nf*2, base_nf*2, kernel_size=4, stride=2, norm_type=norm_type, \
|
||||
act_type=act_type, mode=mode)
|
||||
# 4, 128
|
||||
self.features = B.sequential(conv0, conv1, conv2, conv3)
|
||||
|
||||
# classifier
|
||||
self.classifier = nn.Sequential(
|
||||
nn.Linear(128 * 4 * 4, 100), nn.LeakyReLU(0.2, True), nn.Linear(100, 1))
|
||||
|
||||
def forward(self, x):
|
||||
x = self.features(x)
|
||||
x = x.view(x.size(0), -1)
|
||||
x = self.classifier(x)
|
||||
return x
|
||||
|
||||
|
||||
# VGG style Discriminator with input size 128*128, Spectral Normalization
|
||||
class Discriminator_VGG_128_SN(nn.Module):
|
||||
def __init__(self):
|
||||
super(Discriminator_VGG_128_SN, self).__init__()
|
||||
# features
|
||||
# hxw, c
|
||||
# 128, 64
|
||||
self.lrelu = nn.LeakyReLU(0.2, True)
|
||||
|
||||
self.conv0 = SN.spectral_norm(nn.Conv2d(3, 64, 3, 1, 1))
|
||||
self.conv1 = SN.spectral_norm(nn.Conv2d(64, 64, 4, 2, 1))
|
||||
# 64, 64
|
||||
self.conv2 = SN.spectral_norm(nn.Conv2d(64, 128, 3, 1, 1))
|
||||
self.conv3 = SN.spectral_norm(nn.Conv2d(128, 128, 4, 2, 1))
|
||||
# 32, 128
|
||||
self.conv4 = SN.spectral_norm(nn.Conv2d(128, 256, 3, 1, 1))
|
||||
self.conv5 = SN.spectral_norm(nn.Conv2d(256, 256, 4, 2, 1))
|
||||
# 16, 256
|
||||
self.conv6 = SN.spectral_norm(nn.Conv2d(256, 512, 3, 1, 1))
|
||||
self.conv7 = SN.spectral_norm(nn.Conv2d(512, 512, 4, 2, 1))
|
||||
# 8, 512
|
||||
self.conv8 = SN.spectral_norm(nn.Conv2d(512, 512, 3, 1, 1))
|
||||
self.conv9 = SN.spectral_norm(nn.Conv2d(512, 512, 4, 2, 1))
|
||||
# 4, 512
|
||||
|
||||
# classifier
|
||||
self.linear0 = SN.spectral_norm(nn.Linear(512 * 4 * 4, 100))
|
||||
self.linear1 = SN.spectral_norm(nn.Linear(100, 1))
|
||||
|
||||
def forward(self, x):
|
||||
x = self.lrelu(self.conv0(x))
|
||||
x = self.lrelu(self.conv1(x))
|
||||
x = self.lrelu(self.conv2(x))
|
||||
x = self.lrelu(self.conv3(x))
|
||||
x = self.lrelu(self.conv4(x))
|
||||
x = self.lrelu(self.conv5(x))
|
||||
x = self.lrelu(self.conv6(x))
|
||||
x = self.lrelu(self.conv7(x))
|
||||
x = self.lrelu(self.conv8(x))
|
||||
x = self.lrelu(self.conv9(x))
|
||||
x = x.view(x.size(0), -1)
|
||||
x = self.lrelu(self.linear0(x))
|
||||
x = self.linear1(x)
|
||||
return x
|
||||
|
||||
|
||||
class Discriminator_VGG_96(nn.Module):
|
||||
def __init__(self, in_nc, base_nf, norm_type='batch', act_type='leakyrelu', mode='CNA'):
|
||||
super(Discriminator_VGG_96, self).__init__()
|
||||
# features
|
||||
# hxw, c
|
||||
# 96, 64
|
||||
conv0 = B.conv_block(in_nc, base_nf, kernel_size=3, norm_type=None, act_type=act_type, \
|
||||
mode=mode)
|
||||
conv1 = B.conv_block(base_nf, base_nf, kernel_size=4, stride=2, norm_type=norm_type, \
|
||||
act_type=act_type, mode=mode)
|
||||
# 48, 64
|
||||
conv2 = B.conv_block(base_nf, base_nf*2, kernel_size=3, stride=1, norm_type=norm_type, \
|
||||
act_type=act_type, mode=mode)
|
||||
conv3 = B.conv_block(base_nf*2, base_nf*2, kernel_size=4, stride=2, norm_type=norm_type, \
|
||||
act_type=act_type, mode=mode)
|
||||
# 24, 128
|
||||
conv4 = B.conv_block(base_nf*2, base_nf*4, kernel_size=3, stride=1, norm_type=norm_type, \
|
||||
act_type=act_type, mode=mode)
|
||||
conv5 = B.conv_block(base_nf*4, base_nf*4, kernel_size=4, stride=2, norm_type=norm_type, \
|
||||
act_type=act_type, mode=mode)
|
||||
# 12, 256
|
||||
conv6 = B.conv_block(base_nf*4, base_nf*8, kernel_size=3, stride=1, norm_type=norm_type, \
|
||||
act_type=act_type, mode=mode)
|
||||
conv7 = B.conv_block(base_nf*8, base_nf*8, kernel_size=4, stride=2, norm_type=norm_type, \
|
||||
act_type=act_type, mode=mode)
|
||||
# 6, 512
|
||||
conv8 = B.conv_block(base_nf*8, base_nf*8, kernel_size=3, stride=1, norm_type=norm_type, \
|
||||
act_type=act_type, mode=mode)
|
||||
conv9 = B.conv_block(base_nf*8, base_nf*8, kernel_size=4, stride=2, norm_type=norm_type, \
|
||||
act_type=act_type, mode=mode)
|
||||
# 3, 512
|
||||
self.features = B.sequential(conv0, conv1, conv2, conv3, conv4, conv5, conv6, conv7, conv8,\
|
||||
conv9)
|
||||
|
||||
# classifier
|
||||
self.classifier = nn.Sequential(
|
||||
nn.Linear(512 * 3 * 3, 100), nn.LeakyReLU(0.2, True), nn.Linear(100, 1))
|
||||
|
||||
def forward(self, x):
|
||||
x = self.features(x)
|
||||
x = x.view(x.size(0), -1)
|
||||
x = self.classifier(x)
|
||||
return x
|
||||
|
||||
|
||||
class Discriminator_VGG_192(nn.Module):
|
||||
def __init__(self, in_nc, base_nf, norm_type='batch', act_type='leakyrelu', mode='CNA'):
|
||||
super(Discriminator_VGG_192, self).__init__()
|
||||
# features
|
||||
# hxw, c
|
||||
# 192, 64
|
||||
conv0 = B.conv_block(in_nc, base_nf, kernel_size=3, norm_type=None, act_type=act_type, \
|
||||
mode=mode)
|
||||
conv1 = B.conv_block(base_nf, base_nf, kernel_size=4, stride=2, norm_type=norm_type, \
|
||||
act_type=act_type, mode=mode)
|
||||
# 96, 64
|
||||
conv2 = B.conv_block(base_nf, base_nf*2, kernel_size=3, stride=1, norm_type=norm_type, \
|
||||
act_type=act_type, mode=mode)
|
||||
conv3 = B.conv_block(base_nf*2, base_nf*2, kernel_size=4, stride=2, norm_type=norm_type, \
|
||||
act_type=act_type, mode=mode)
|
||||
# 48, 128
|
||||
conv4 = B.conv_block(base_nf*2, base_nf*4, kernel_size=3, stride=1, norm_type=norm_type, \
|
||||
act_type=act_type, mode=mode)
|
||||
conv5 = B.conv_block(base_nf*4, base_nf*4, kernel_size=4, stride=2, norm_type=norm_type, \
|
||||
act_type=act_type, mode=mode)
|
||||
# 24, 256
|
||||
conv6 = B.conv_block(base_nf*4, base_nf*8, kernel_size=3, stride=1, norm_type=norm_type, \
|
||||
act_type=act_type, mode=mode)
|
||||
conv7 = B.conv_block(base_nf*8, base_nf*8, kernel_size=4, stride=2, norm_type=norm_type, \
|
||||
act_type=act_type, mode=mode)
|
||||
# 12, 512
|
||||
conv8 = B.conv_block(base_nf*8, base_nf*8, kernel_size=3, stride=1, norm_type=norm_type, \
|
||||
act_type=act_type, mode=mode)
|
||||
conv9 = B.conv_block(base_nf*8, base_nf*8, kernel_size=4, stride=2, norm_type=norm_type, \
|
||||
act_type=act_type, mode=mode)
|
||||
# 6, 512
|
||||
conv10 = B.conv_block(base_nf*8, base_nf*8, kernel_size=3, stride=1, norm_type=norm_type, \
|
||||
act_type=act_type, mode=mode)
|
||||
conv11 = B.conv_block(base_nf*8, base_nf*8, kernel_size=4, stride=2, norm_type=norm_type, \
|
||||
act_type=act_type, mode=mode)
|
||||
# 3, 512
|
||||
self.features = B.sequential(conv0, conv1, conv2, conv3, conv4, conv5, conv6, conv7, conv8,\
|
||||
conv9, conv10, conv11)
|
||||
|
||||
# classifier
|
||||
self.classifier = nn.Sequential(
|
||||
nn.Linear(512 * 3 * 3, 100), nn.LeakyReLU(0.2, True), nn.Linear(100, 1))
|
||||
|
||||
def forward(self, x):
|
||||
x = self.features(x)
|
||||
x = x.view(x.size(0), -1)
|
||||
x = self.classifier(x)
|
||||
return x
|
||||
|
||||
|
||||
####################
|
||||
# Perceptual Network
|
||||
####################
|
||||
|
|
|
@ -29,6 +29,8 @@ def norm(norm_type, nc):
|
|||
layer = nn.BatchNorm2d(nc, affine=True)
|
||||
elif norm_type == 'instance':
|
||||
layer = nn.InstanceNorm2d(nc, affine=False)
|
||||
elif norm_type == 'group':
|
||||
layer = nn.GroupNorm(8, nc)
|
||||
else:
|
||||
raise NotImplementedError('normalization layer [{:s}] is not found'.format(norm_type))
|
||||
return layer
|
||||
|
|
|
@ -2,7 +2,7 @@ import torch
|
|||
import torch.nn as nn
|
||||
|
||||
|
||||
# Define GAN loss: [vanilla | lsgan | wgan-gp]
|
||||
# Define GAN loss: [vanilla | lsgan]
|
||||
class GANLoss(nn.Module):
|
||||
def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0):
|
||||
super(GANLoss, self).__init__()
|
||||
|
@ -14,19 +14,10 @@ class GANLoss(nn.Module):
|
|||
self.loss = nn.BCEWithLogitsLoss()
|
||||
elif self.gan_type == 'lsgan':
|
||||
self.loss = nn.MSELoss()
|
||||
elif self.gan_type == 'wgan-gp':
|
||||
|
||||
def wgan_loss(input, target):
|
||||
# target is boolean
|
||||
return -1 * input.mean() if target else input.mean()
|
||||
|
||||
self.loss = wgan_loss
|
||||
else:
|
||||
raise NotImplementedError('GAN type [{:s}] is not found'.format(self.gan_type))
|
||||
|
||||
def get_target_label(self, input, target_is_real):
|
||||
if self.gan_type == 'wgan-gp':
|
||||
return target_is_real
|
||||
if target_is_real:
|
||||
return torch.empty_like(input).fill_(self.real_label_val)
|
||||
else:
|
||||
|
@ -36,25 +27,3 @@ class GANLoss(nn.Module):
|
|||
target_label = self.get_target_label(input, target_is_real)
|
||||
loss = self.loss(input, target_label)
|
||||
return loss
|
||||
|
||||
|
||||
class GradientPenaltyLoss(nn.Module):
|
||||
def __init__(self, device=torch.device('cpu')):
|
||||
super(GradientPenaltyLoss, self).__init__()
|
||||
self.register_buffer('grad_outputs', torch.Tensor())
|
||||
self.grad_outputs = self.grad_outputs.to(device)
|
||||
|
||||
def get_grad_outputs(self, input):
|
||||
if self.grad_outputs.size() != input.size():
|
||||
self.grad_outputs.resize_(input.size()).fill_(1.0)
|
||||
return self.grad_outputs
|
||||
|
||||
def forward(self, interp, interp_crit):
|
||||
grad_outputs = self.get_grad_outputs(interp_crit)
|
||||
grad_interp = torch.autograd.grad(outputs=interp_crit, inputs=interp, \
|
||||
grad_outputs=grad_outputs, create_graph=True, retain_graph=True, only_inputs=True)[0]
|
||||
grad_interp = grad_interp.view(grad_interp.size(0), -1)
|
||||
grad_interp_norm = grad_interp.norm(2, dim=1)
|
||||
|
||||
loss = ((grad_interp_norm - 1)**2).mean()
|
||||
return loss
|
||||
|
|
|
@ -1,149 +0,0 @@
|
|||
'''
|
||||
Copy from pytorch github repo
|
||||
Spectral Normalization from https://arxiv.org/abs/1802.05957
|
||||
'''
|
||||
import torch
|
||||
from torch.nn.functional import normalize
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
|
||||
class SpectralNorm(object):
|
||||
def __init__(self, name='weight', n_power_iterations=1, dim=0, eps=1e-12):
|
||||
self.name = name
|
||||
self.dim = dim
|
||||
if n_power_iterations <= 0:
|
||||
raise ValueError('Expected n_power_iterations to be positive, but '
|
||||
'got n_power_iterations={}'.format(n_power_iterations))
|
||||
self.n_power_iterations = n_power_iterations
|
||||
self.eps = eps
|
||||
|
||||
def compute_weight(self, module):
|
||||
weight = getattr(module, self.name + '_orig')
|
||||
u = getattr(module, self.name + '_u')
|
||||
weight_mat = weight
|
||||
if self.dim != 0:
|
||||
# permute dim to front
|
||||
weight_mat = weight_mat.permute(self.dim,
|
||||
*[d for d in range(weight_mat.dim()) if d != self.dim])
|
||||
height = weight_mat.size(0)
|
||||
weight_mat = weight_mat.reshape(height, -1)
|
||||
with torch.no_grad():
|
||||
for _ in range(self.n_power_iterations):
|
||||
# Spectral norm of weight equals to `u^T W v`, where `u` and `v`
|
||||
# are the first left and right singular vectors.
|
||||
# This power iteration produces approximations of `u` and `v`.
|
||||
v = normalize(torch.matmul(weight_mat.t(), u), dim=0, eps=self.eps)
|
||||
u = normalize(torch.matmul(weight_mat, v), dim=0, eps=self.eps)
|
||||
|
||||
sigma = torch.dot(u, torch.matmul(weight_mat, v))
|
||||
weight = weight / sigma
|
||||
return weight, u
|
||||
|
||||
def remove(self, module):
|
||||
weight = getattr(module, self.name)
|
||||
delattr(module, self.name)
|
||||
delattr(module, self.name + '_u')
|
||||
delattr(module, self.name + '_orig')
|
||||
module.register_parameter(self.name, torch.nn.Parameter(weight))
|
||||
|
||||
def __call__(self, module, inputs):
|
||||
if module.training:
|
||||
weight, u = self.compute_weight(module)
|
||||
setattr(module, self.name, weight)
|
||||
setattr(module, self.name + '_u', u)
|
||||
else:
|
||||
r_g = getattr(module, self.name + '_orig').requires_grad
|
||||
getattr(module, self.name).detach_().requires_grad_(r_g)
|
||||
|
||||
@staticmethod
|
||||
def apply(module, name, n_power_iterations, dim, eps):
|
||||
fn = SpectralNorm(name, n_power_iterations, dim, eps)
|
||||
weight = module._parameters[name]
|
||||
height = weight.size(dim)
|
||||
|
||||
u = normalize(weight.new_empty(height).normal_(0, 1), dim=0, eps=fn.eps)
|
||||
delattr(module, fn.name)
|
||||
module.register_parameter(fn.name + "_orig", weight)
|
||||
# We still need to assign weight back as fn.name because all sorts of
|
||||
# things may assume that it exists, e.g., when initializing weights.
|
||||
# However, we can't directly assign as it could be an nn.Parameter and
|
||||
# gets added as a parameter. Instead, we register weight.data as a
|
||||
# buffer, which will cause weight to be included in the state dict
|
||||
# and also supports nn.init due to shared storage.
|
||||
module.register_buffer(fn.name, weight.data)
|
||||
module.register_buffer(fn.name + "_u", u)
|
||||
|
||||
module.register_forward_pre_hook(fn)
|
||||
return fn
|
||||
|
||||
|
||||
def spectral_norm(module, name='weight', n_power_iterations=1, eps=1e-12, dim=None):
|
||||
r"""Applies spectral normalization to a parameter in the given module.
|
||||
|
||||
.. math::
|
||||
\mathbf{W} &= \dfrac{\mathbf{W}}{\sigma(\mathbf{W})} \\
|
||||
\sigma(\mathbf{W}) &= \max_{\mathbf{h}: \mathbf{h} \ne 0} \dfrac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2}
|
||||
|
||||
Spectral normalization stabilizes the training of discriminators (critics)
|
||||
in Generaive Adversarial Networks (GANs) by rescaling the weight tensor
|
||||
with spectral norm :math:`\sigma` of the weight matrix calculated using
|
||||
power iteration method. If the dimension of the weight tensor is greater
|
||||
than 2, it is reshaped to 2D in power iteration method to get spectral
|
||||
norm. This is implemented via a hook that calculates spectral norm and
|
||||
rescales weight before every :meth:`~Module.forward` call.
|
||||
|
||||
See `Spectral Normalization for Generative Adversarial Networks`_ .
|
||||
|
||||
.. _`Spectral Normalization for Generative Adversarial Networks`: https://arxiv.org/abs/1802.05957
|
||||
|
||||
Args:
|
||||
module (nn.Module): containing module
|
||||
name (str, optional): name of weight parameter
|
||||
n_power_iterations (int, optional): number of power iterations to
|
||||
calculate spectal norm
|
||||
eps (float, optional): epsilon for numerical stability in
|
||||
calculating norms
|
||||
dim (int, optional): dimension corresponding to number of outputs,
|
||||
the default is 0, except for modules that are instances of
|
||||
ConvTranspose1/2/3d, when it is 1
|
||||
|
||||
Returns:
|
||||
The original module with the spectal norm hook
|
||||
|
||||
Example::
|
||||
|
||||
>>> m = spectral_norm(nn.Linear(20, 40))
|
||||
Linear (20 -> 40)
|
||||
>>> m.weight_u.size()
|
||||
torch.Size([20])
|
||||
|
||||
"""
|
||||
if dim is None:
|
||||
if isinstance(
|
||||
module,
|
||||
(torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d)):
|
||||
dim = 1
|
||||
else:
|
||||
dim = 0
|
||||
SpectralNorm.apply(module, name, n_power_iterations, dim, eps)
|
||||
return module
|
||||
|
||||
|
||||
def remove_spectral_norm(module, name='weight'):
|
||||
r"""Removes the spectral normalization reparameterization from a module.
|
||||
|
||||
Args:
|
||||
module (nn.Module): containing module
|
||||
name (str, optional): name of weight parameter
|
||||
|
||||
Example:
|
||||
>>> m = spectral_norm(nn.Linear(40, 10))
|
||||
>>> remove_spectral_norm(m)
|
||||
"""
|
||||
for k, hook in module._forward_pre_hooks.items():
|
||||
if isinstance(hook, SpectralNorm) and hook.name == name:
|
||||
hook.remove(module)
|
||||
del module._forward_pre_hooks[k]
|
||||
return module
|
||||
|
||||
raise ValueError("spectral_norm of '{}' not found in {}".format(name, module))
|
|
@ -18,7 +18,7 @@ class CharbonnierLoss(nn.Module):
|
|||
return loss
|
||||
|
||||
|
||||
# Define GAN loss: [vanilla | lsgan | wgan-gp]
|
||||
# Define GAN loss: [vanilla | lsgan]
|
||||
class GANLoss(nn.Module):
|
||||
def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0):
|
||||
super(GANLoss, self).__init__()
|
||||
|
@ -30,19 +30,10 @@ class GANLoss(nn.Module):
|
|||
self.loss = nn.BCEWithLogitsLoss()
|
||||
elif self.gan_type == 'lsgan':
|
||||
self.loss = nn.MSELoss()
|
||||
elif self.gan_type == 'wgan-gp':
|
||||
|
||||
def wgan_loss(input, target):
|
||||
# target is boolean
|
||||
return -1 * input.mean() if target else input.mean()
|
||||
|
||||
self.loss = wgan_loss
|
||||
else:
|
||||
raise NotImplementedError('GAN type [{:s}] is not found'.format(self.gan_type))
|
||||
|
||||
def get_target_label(self, input, target_is_real):
|
||||
if self.gan_type == 'wgan-gp':
|
||||
return target_is_real
|
||||
if target_is_real:
|
||||
return torch.empty_like(input).fill_(self.real_label_val)
|
||||
else:
|
||||
|
@ -57,29 +48,6 @@ class GANLoss(nn.Module):
|
|||
return loss
|
||||
|
||||
|
||||
class GradientPenaltyLoss(nn.Module):
|
||||
def __init__(self, device=torch.device('cpu')):
|
||||
super(GradientPenaltyLoss, self).__init__()
|
||||
self.register_buffer('grad_outputs', torch.Tensor())
|
||||
self.grad_outputs = self.grad_outputs.to(device)
|
||||
|
||||
def get_grad_outputs(self, input):
|
||||
if self.grad_outputs.size() != input.size():
|
||||
self.grad_outputs.resize_(input.size()).fill_(1.0)
|
||||
return self.grad_outputs
|
||||
|
||||
def forward(self, interp, interp_crit):
|
||||
grad_outputs = self.get_grad_outputs(interp_crit)
|
||||
grad_interp = torch.autograd.grad(outputs=interp_crit, inputs=interp,
|
||||
grad_outputs=grad_outputs, create_graph=True,
|
||||
retain_graph=True, only_inputs=True)[0]
|
||||
grad_interp = grad_interp.view(grad_interp.size(0), -1)
|
||||
grad_interp_norm = grad_interp.norm(2, dim=1)
|
||||
|
||||
loss = ((grad_interp_norm - 1)**2).mean()
|
||||
return loss
|
||||
|
||||
|
||||
# Frequency Domain Perceptual Loss, from https://github.com/sdv4/FDPL
|
||||
# Utilizes pre-computed perceptual_weights. To generate these from your dataset, see data_scripts/compute_fdpl_perceptual_weights.py
|
||||
# In practice, per the paper, these precomputed weights can generally be used across broad image classes (e.g. all photographs).
|
||||
|
|
|
@ -215,7 +215,7 @@ def main():
|
|||
logger.info(message)
|
||||
#### validation
|
||||
if opt['datasets'].get('val', None) and current_step % opt['train']['val_freq'] == 0:
|
||||
if opt['model'] in ['sr', 'srgan', 'corruptgan'] and rank <= 0: # image restoration validation
|
||||
if opt['model'] in ['sr', 'srgan', 'corruptgan', 'spsr'] and rank <= 0: # image restoration validation
|
||||
model.force_restore_swapout()
|
||||
val_batch_sz = 1 if 'batch_size' not in opt['datasets']['val'].keys() else opt['datasets']['val']['batch_size']
|
||||
# does not support multi-GPU validation
|
||||
|
|
Loading…
Reference in New Issue
Block a user