Rework stylegan2 divergence losses

Notably: include unet loss
This commit is contained in:
James Betker 2020-11-15 11:26:44 -07:00
parent ea94b93a37
commit 99f0cfaab5
4 changed files with 193 additions and 71 deletions

View File

@ -0,0 +1,14 @@
from models.archs.stylegan.stylegan2 import StyleGan2DivergenceLoss, StyleGan2PathLengthLoss
from models.archs.stylegan.stylegan2_unet_disc import StyleGan2UnetDivergenceLoss
def create_stylegan2_loss(opt_loss, env):
type = opt_loss['type']
if type == 'stylegan2_divergence':
return StyleGan2DivergenceLoss(opt_loss, env)
elif type == 'stylegan2_pathlen':
return StyleGan2PathLengthLoss(opt_loss, env)
elif type == 'stylegan2_unet_divergence':
return StyleGan2UnetDivergenceLoss(opt_loss, env)
else:
raise NotImplementedError

View File

@ -13,6 +13,7 @@ from torch import nn
from torch.autograd import grad as torch_grad
from vector_quantize_pytorch import VectorQuantize
from models.steps.losses import ConfigurableLoss
from utils.util import checkpoint
try:
@ -304,6 +305,9 @@ class StyleGan2Augmentor(nn.Module):
if detach:
images = images.detach()
# Save away for use elsewhere (e.g. unet loss)
self.aug_images = images
return self.D(images)
@ -693,4 +697,69 @@ class StyleGan2Discriminator(nn.Module):
def _init_weights(self):
for m in self.modules():
if type(m) in {nn.Conv2d, nn.Linear}:
nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')
nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')
class StyleGan2DivergenceLoss(ConfigurableLoss):
def __init__(self, opt, env):
super().__init__(opt, env)
self.real = opt['real']
self.fake = opt['fake']
self.discriminator = opt['discriminator']
self.for_gen = opt['gen_loss']
self.gp_frequency = opt['gradient_penalty_frequency']
self.noise = opt['noise'] if 'noise' in opt.keys() else 0
def forward(self, net, state):
real_input = state[self.real]
fake_input = state[self.fake]
if self.noise != 0:
fake_input = fake_input + torch.rand_like(fake_input) * self.noise
real_input = real_input + torch.rand_like(real_input) * self.noise
D = self.env['discriminators'][self.discriminator]
fake = D(fake_input)
if self.for_gen:
return fake.mean()
else:
real_input.requires_grad_() # <-- Needed to compute gradients on the input.
real = D(real_input)
divergence_loss = (F.relu(1 + real) + F.relu(1 - fake)).mean()
# Apply gradient penalty. TODO: migrate this elsewhere.
if self.env['step'] % self.gp_frequency == 0:
from models.archs.stylegan.stylegan2 import gradient_penalty
gp = gradient_penalty(real_input, real)
self.metrics.append(("gradient_penalty", gp.clone().detach()))
divergence_loss = divergence_loss + gp
real_input.requires_grad_(requires_grad=False)
return divergence_loss
class StyleGan2PathLengthLoss(ConfigurableLoss):
def __init__(self, opt, env):
super().__init__(opt, env)
self.w_styles = opt['w_styles']
self.gen = opt['gen']
self.pl_mean = None
from models.archs.stylegan.stylegan2 import EMA
self.pl_length_ma = EMA(.99)
def forward(self, net, state):
w_styles = state[self.w_styles]
gen = state[self.gen]
from models.archs.stylegan.stylegan2 import calc_pl_lengths
pl_lengths = calc_pl_lengths(w_styles, gen)
avg_pl_length = np.mean(pl_lengths.detach().cpu().numpy())
from models.archs.stylegan.stylegan2 import is_empty
if not is_empty(self.pl_mean):
pl_loss = ((pl_lengths - self.pl_mean) ** 2).mean()
if not torch.isnan(pl_loss):
return pl_loss
else:
print("Path length loss returned NaN!")
self.pl_mean = self.pl_length_ma.update_average(self.pl_mean, avg_pl_length)
return 0

View File

@ -1,10 +1,14 @@
from functools import partial
from math import log2
from random import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from models.archs.stylegan.stylegan2 import attn_and_ff
from models.steps.losses import ConfigurableLoss
def leaky_relu(p=0.2):
@ -132,4 +136,105 @@ class StyleGan2UnetDiscriminator(nn.Module):
x = up_block(x, res)
dec_out = self.conv_out(x)
return dec_out
return dec_out, enc_out
def warmup(start, end, max_steps, current_step):
if current_step > max_steps:
return end
return (end - start) * (current_step / max_steps) + start
def mask_src_tgt(source, target, mask):
return source * mask + (1 - mask) * target
def cutmix(source, target, coors, alpha = 1.):
source, target = map(torch.clone, (source, target))
((y0, y1), (x0, x1)), _ = coors
source[:, :, y0:y1, x0:x1] = target[:, :, y0:y1, x0:x1]
return source
def cutmix_coordinates(height, width, alpha = 1.):
lam = np.random.beta(alpha, alpha)
cx = np.random.uniform(0, width)
cy = np.random.uniform(0, height)
w = width * np.sqrt(1 - lam)
h = height * np.sqrt(1 - lam)
x0 = int(np.round(max(cx - w / 2, 0)))
x1 = int(np.round(min(cx + w / 2, width)))
y0 = int(np.round(max(cy - h / 2, 0)))
y1 = int(np.round(min(cy + h / 2, height)))
return ((y0, y1), (x0, x1)), lam
class StyleGan2UnetDivergenceLoss(ConfigurableLoss):
def __init__(self, opt, env):
super().__init__(opt, env)
self.real = opt['real']
self.fake = opt['fake']
self.discriminator = opt['discriminator']
self.for_gen = opt['gen_loss']
self.gp_frequency = opt['gradient_penalty_frequency']
self.noise = opt['noise'] if 'noise' in opt.keys() else 0
self.image_size = opt['image_size']
def forward(self, net, state):
real_input = state[self.real]
fake_input = state[self.fake]
if self.noise != 0:
fake_input = fake_input + torch.rand_like(fake_input) * self.noise
real_input = real_input + torch.rand_like(real_input) * self.noise
D = self.env['discriminators'][self.discriminator]
fake_dec, fake_enc = D(fake_input)
fake_aug_images = D.aug_images
if self.for_gen:
return fake_enc.mean() + F.relu(1 + fake_dec).mean()
else:
dec_loss_coef = warmup(0, 1., 30000, self.env['step'])
cutmix_prob = warmup(0, 0.25, 30000, self.env['step'])
apply_cutmix = random() < cutmix_prob
real_input.requires_grad_() # <-- Needed to compute gradients on the input.
real_dec, real_enc = D(real_input)
real_aug_images = D.aug_images
enc_divergence = (F.relu(1 + real_enc) + F.relu(1 - fake_enc)).mean()
dec_divergence = (F.relu(1 + real_dec) + F.relu(1 - fake_dec)).mean()
divergence_loss = enc_divergence + dec_divergence * dec_loss_coef
if apply_cutmix:
mask = cutmix(
torch.ones_like(real_dec),
torch.zeros_like(real_dec),
cutmix_coordinates(self.image_size, self.image_size)
)
if random() > 0.5:
mask = 1 - mask
cutmix_images = mask_src_tgt(real_aug_images, fake_aug_images, mask)
cutmix_enc_out, cutmix_dec_out = self.GAN.D(cutmix_images)
cutmix_enc_divergence = F.relu(1 - cutmix_enc_out).mean()
cutmix_dec_divergence = F.relu(1 + (mask * 2 - 1) * cutmix_dec_out).mean()
disc_loss = divergence_loss + cutmix_enc_divergence + cutmix_dec_divergence
cr_cutmix_dec_out = mask_src_tgt(real_dec, fake_dec, mask)
cr_loss = F.mse_loss(cutmix_dec_out, cr_cutmix_dec_out) * self.cr_weight
self.last_cr_loss = cr_loss.clone().detach().item()
disc_loss = disc_loss + cr_loss * dec_loss_coef
# Apply gradient penalty. TODO: migrate this elsewhere.
if self.env['step'] % self.gp_frequency == 0:
from models.archs.stylegan.stylegan2 import gradient_penalty
gp = gradient_penalty(real_input, real)
self.metrics.append(("gradient_penalty", gp.clone().detach()))
disc_loss = disc_loss + gp
real_input.requires_grad_(requires_grad=False)
return disc_loss

View File

@ -15,6 +15,9 @@ def create_loss(opt_loss, env):
if 'teco_' in type:
from models.steps.tecogan_losses import create_teco_loss
return create_teco_loss(opt_loss, env)
elif 'stylegan2_' in type:
from models.archs.stylegan import create_stylegan2_loss
return create_stylegan2_loss(opt_loss, env)
elif type == 'pix':
return PixLoss(opt_loss, env)
elif type == 'direct':
@ -37,10 +40,6 @@ def create_loss(opt_loss, env):
return RecurrentLoss(opt_loss, env)
elif type == 'for_element':
return ForElementLoss(opt_loss, env)
elif type == 'stylegan2_divergence':
return StyleGan2DivergenceLoss(opt_loss, env)
elif type == 'stylegan2_pathlen':
return StyleGan2PathLengthLoss(opt_loss, env)
else:
raise NotImplementedError
@ -487,68 +486,3 @@ class ForElementLoss(ConfigurableLoss):
def clear_metrics(self):
self.loss.clear_metrics()
class StyleGan2DivergenceLoss(ConfigurableLoss):
def __init__(self, opt, env):
super().__init__(opt, env)
self.real = opt['real']
self.fake = opt['fake']
self.discriminator = opt['discriminator']
self.for_gen = opt['gen_loss']
self.gp_frequency = opt['gradient_penalty_frequency']
self.noise = opt['noise'] if 'noise' in opt.keys() else 0
def forward(self, net, state):
real_input = state[self.real]
fake_input = state[self.fake]
if self.noise != 0:
fake_input = fake_input + torch.rand_like(fake_input) * self.noise
real_input = real_input + torch.rand_like(real_input) * self.noise
D = self.env['discriminators'][self.discriminator]
fake = D(fake_input)
if self.for_gen:
return fake.mean()
else:
real_input.requires_grad_() # <-- Needed to compute gradients on the input.
real = D(real_input)
divergence_loss = (F.relu(1 + real) + F.relu(1 - fake)).mean()
# Apply gradient penalty. TODO: migrate this elsewhere.
if self.env['step'] % self.gp_frequency == 0:
from models.archs.stylegan.stylegan2 import gradient_penalty
gp = gradient_penalty(real_input, real)
self.metrics.append(("gradient_penalty", gp.clone().detach()))
divergence_loss = divergence_loss + gp
real_input.requires_grad_(requires_grad=False)
return divergence_loss
class StyleGan2PathLengthLoss(ConfigurableLoss):
def __init__(self, opt, env):
super().__init__(opt, env)
self.w_styles = opt['w_styles']
self.gen = opt['gen']
self.pl_mean = None
from models.archs.stylegan.stylegan2 import EMA
self.pl_length_ma = EMA(.99)
def forward(self, net, state):
w_styles = state[self.w_styles]
gen = state[self.gen]
from models.archs.stylegan.stylegan2 import calc_pl_lengths
pl_lengths = calc_pl_lengths(w_styles, gen)
avg_pl_length = np.mean(pl_lengths.detach().cpu().numpy())
from models.archs.stylegan.stylegan2 import is_empty
if not is_empty(self.pl_mean):
pl_loss = ((pl_lengths - self.pl_mean) ** 2).mean()
if not torch.isnan(pl_loss):
return pl_loss
else:
print("Path length loss returned NaN!")
self.pl_mean = self.pl_length_ma.update_average(self.pl_mean, avg_pl_length)
return 0