Rework stylegan2 divergence losses

Notably: include unet loss
This commit is contained in:
James Betker 2020-11-15 11:26:44 -07:00
parent ea94b93a37
commit 99f0cfaab5
4 changed files with 193 additions and 71 deletions

View File

@ -0,0 +1,14 @@
from models.archs.stylegan.stylegan2 import StyleGan2DivergenceLoss, StyleGan2PathLengthLoss
from models.archs.stylegan.stylegan2_unet_disc import StyleGan2UnetDivergenceLoss
def create_stylegan2_loss(opt_loss, env):
    """Build the StyleGAN2 loss selected by ``opt_loss['type']``.

    Args:
        opt_loss: Loss configuration mapping; its ``'type'`` entry picks the
            loss class and the whole mapping is handed to that class.
        env: Training environment object, forwarded unchanged to the loss.

    Returns:
        The constructed loss, i.e. ``LossClass(opt_loss, env)``.

    Raises:
        NotImplementedError: If ``opt_loss['type']`` is not a supported name.
    """
    # Deliberately not called ``type`` so the builtin is not shadowed.
    loss_type = opt_loss['type']
    if loss_type == 'stylegan2_divergence':
        return StyleGan2DivergenceLoss(opt_loss, env)
    if loss_type == 'stylegan2_pathlen':
        return StyleGan2PathLengthLoss(opt_loss, env)
    if loss_type == 'stylegan2_unet_divergence':
        return StyleGan2UnetDivergenceLoss(opt_loss, env)
    # Name the offending value so a config typo is easy to diagnose.
    raise NotImplementedError(
        'Unrecognized stylegan2 loss type: {!r}'.format(loss_type))

View File

@ -13,6 +13,7 @@ from torch import nn
from torch.autograd import grad as torch_grad from torch.autograd import grad as torch_grad
from vector_quantize_pytorch import VectorQuantize from vector_quantize_pytorch import VectorQuantize
from models.steps.losses import ConfigurableLoss
from utils.util import checkpoint from utils.util import checkpoint
try: try:
@ -304,6 +305,9 @@ class StyleGan2Augmentor(nn.Module):
if detach: if detach:
images = images.detach() images = images.detach()
# Save away for use elsewhere (e.g. unet loss)
self.aug_images = images
return self.D(images) return self.D(images)
@ -694,3 +698,68 @@ class StyleGan2Discriminator(nn.Module):
for m in self.modules(): for m in self.modules():
if type(m) in {nn.Conv2d, nn.Linear}: if type(m) in {nn.Conv2d, nn.Linear}:
nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu') nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')
class StyleGan2DivergenceLoss(ConfigurableLoss):
    """Divergence loss evaluated through a discriminator from the environment.

    Depending on ``opt['gen_loss']`` this yields either the generator term
    (mean discriminator output on the fake input) or the discriminator term
    ``mean(relu(1 + D(real)) + relu(1 - D(fake)))``, periodically augmented
    with a gradient penalty on the real input.
    """

    def __init__(self, opt, env):
        """Read the state keys and hyper-parameters out of ``opt``.

        Keys read: ``real``, ``fake``, ``discriminator``, ``gen_loss``,
        ``gradient_penalty_frequency`` and the optional ``noise`` (default 0).
        """
        super().__init__(opt, env)
        self.real, self.fake = opt['real'], opt['fake']
        self.discriminator = opt['discriminator']
        self.for_gen = opt['gen_loss']
        self.gp_frequency = opt['gradient_penalty_frequency']
        if 'noise' in opt:
            self.noise = opt['noise']
        else:
            self.noise = 0

    def forward(self, net, state):
        """Compute the generator or discriminator loss from ``state``.

        ``net`` is not used by this method.
        """
        real_input, fake_input = state[self.real], state[self.fake]
        if self.noise != 0:
            # Fake is perturbed before real so the random draws keep their order.
            fake_input = fake_input + torch.rand_like(fake_input) * self.noise
            real_input = real_input + torch.rand_like(real_input) * self.noise

        disc = self.env['discriminators'][self.discriminator]
        fake_out = disc(fake_input)
        if self.for_gen:
            return fake_out.mean()

        # The gradient penalty differentiates w.r.t. the real input, so it
        # has to track gradients while it passes through the discriminator.
        real_input.requires_grad_()
        real_out = disc(real_input)
        loss = (F.relu(1 + real_out) + F.relu(1 - fake_out)).mean()

        # Apply gradient penalty. TODO: migrate this elsewhere.
        if self.env['step'] % self.gp_frequency == 0:
            from models.archs.stylegan.stylegan2 import gradient_penalty
            penalty = gradient_penalty(real_input, real_out)
            self.metrics.append(("gradient_penalty", penalty.clone().detach()))
            loss = loss + penalty

        # Restore the flag on the (possibly shared) input tensor.
        real_input.requires_grad_(requires_grad=False)
        return loss
class StyleGan2PathLengthLoss(ConfigurableLoss):
    """Path-length penalty: squared deviation of path lengths from their running mean.

    The running mean (``self.pl_mean``) is maintained with an ``EMA(.99)``
    helper and is refreshed on every call, so the penalty keeps tracking the
    generator as it trains.
    """

    def __init__(self, opt, env):
        """Read the state keys ``opt['w_styles']`` and ``opt['gen']``."""
        super().__init__(opt, env)
        self.w_styles = opt['w_styles']
        self.gen = opt['gen']
        # No running mean exists until the first batch has been observed.
        self.pl_mean = None
        from models.archs.stylegan.stylegan2 import EMA
        self.pl_length_ma = EMA(.99)

    def forward(self, net, state):
        """Return the path-length loss for the current batch.

        ``net`` is not used. Returns the integer ``0`` when no running mean
        exists yet or when the computed loss is NaN; otherwise returns the
        mean squared deviation of the path lengths from the running mean.
        """
        from models.archs.stylegan.stylegan2 import calc_pl_lengths, is_empty
        w_styles = state[self.w_styles]
        gen = state[self.gen]
        pl_lengths = calc_pl_lengths(w_styles, gen)
        avg_pl_length = np.mean(pl_lengths.detach().cpu().numpy())

        pl_loss = 0
        if not is_empty(self.pl_mean):
            candidate = ((pl_lengths - self.pl_mean) ** 2).mean()
            if torch.isnan(candidate):
                print("Path length loss returned NaN!")
            else:
                pl_loss = candidate

        # Fix: the previous version returned before reaching this update
        # whenever a valid loss was produced, which froze pl_mean at the value
        # seen on the very first batch. The loss above still uses the
        # pre-update mean. NaN averages are skipped so one bad batch cannot
        # poison the running mean.
        if not np.isnan(avg_pl_length):
            self.pl_mean = self.pl_length_ma.update_average(self.pl_mean, avg_pl_length)
        return pl_loss

View File

@ -1,10 +1,14 @@
from functools import partial from functools import partial
from math import log2 from math import log2
from random import random
import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F
from models.archs.stylegan.stylegan2 import attn_and_ff from models.archs.stylegan.stylegan2 import attn_and_ff
from models.steps.losses import ConfigurableLoss
def leaky_relu(p=0.2): def leaky_relu(p=0.2):
@ -132,4 +136,105 @@ class StyleGan2UnetDiscriminator(nn.Module):
x = up_block(x, res) x = up_block(x, res)
dec_out = self.conv_out(x) dec_out = self.conv_out(x)
return dec_out return dec_out, enc_out
def warmup(start, end, max_steps, current_step):
    """Linearly ramp a value from ``start`` to ``end`` over ``max_steps`` steps.

    Args:
        start: Value returned at step 0.
        end: Value returned once ``current_step`` reaches ``max_steps``.
        max_steps: Length of the ramp in steps; a non-positive length means
            there is no ramp and ``end`` is returned immediately.
        current_step: The current training step.

    Returns:
        ``end`` when the ramp has finished (or has no length), otherwise the
        linear interpolation between ``start`` and ``end``.
    """
    # ``>=`` (rather than ``>``) returns exactly ``end`` at the boundary step
    # instead of a float-rounded interpolation, and together with the
    # ``max_steps <= 0`` guard it removes the division by zero that a
    # zero-length warmup used to trigger.
    if max_steps <= 0 or current_step >= max_steps:
        return end
    return (end - start) * (current_step / max_steps) + start
def mask_src_tgt(source, target, mask):
    """Blend two values: ``source`` where ``mask`` is 1, ``target`` where it is 0."""
    from_source = source * mask
    from_target = (1 - mask) * target
    return from_source + from_target
def cutmix(source, target, coors, alpha = 1.):
    """Return a copy of ``source`` with one rectangle pasted in from ``target``.

    ``coors`` has the form ``(((y0, y1), (x0, x1)), lam)``; only the rectangle
    is used. The rectangle indexes the last two dimensions of 4-D inputs.
    Neither input is modified. ``alpha`` is accepted but unused.
    """
    patched = torch.clone(source)
    donor = torch.clone(target)
    (row_span, col_span), _ = coors
    top, bottom = row_span
    left, right = col_span
    patched[:, :, top:bottom, left:right] = donor[:, :, top:bottom, left:right]
    return patched
def cutmix_coordinates(height, width, alpha = 1.):
    """Sample a random rectangle inside a ``height`` x ``width`` area.

    A ratio ``lam`` is drawn from ``Beta(alpha, alpha)``; the rectangle has
    side lengths ``sqrt(1 - lam)`` times the full extent, is centred on a
    uniformly drawn point and is clipped to the area bounds.

    Returns:
        ``(((y0, y1), (x0, x1)), lam)`` with integer corner coordinates.
    """
    # Draw order (beta, x-centre, y-centre) is kept so seeded runs reproduce.
    lam = np.random.beta(alpha, alpha)
    center_x = np.random.uniform(0, width)
    center_y = np.random.uniform(0, height)

    side_scale = np.sqrt(1 - lam)
    half_w = width * side_scale / 2
    half_h = height * side_scale / 2

    left = int(np.round(max(center_x - half_w, 0)))
    right = int(np.round(min(center_x + half_w, width)))
    top = int(np.round(max(center_y - half_h, 0)))
    bottom = int(np.round(min(center_y + half_h, height)))

    return ((top, bottom), (left, right)), lam
class StyleGan2UnetDivergenceLoss(ConfigurableLoss):
    """Divergence loss for a discriminator that returns ``(dec_out, enc_out)``.

    The generator term is ``mean(fake_enc) + mean(relu(1 + fake_dec))``.
    The discriminator term sums the encoder and decoder divergences (the
    decoder part is warmed up over 30000 steps) and, with a probability that
    is also warmed up, adds CutMix and consistency terms. A gradient penalty
    on the real input is added every ``gradient_penalty_frequency`` steps.

    The discriminator taken from ``env`` must expose ``aug_images`` after a
    forward pass (the images it actually scored).
    """

    def __init__(self, opt, env):
        """Read the state keys and hyper-parameters out of ``opt``.

        Keys read: ``real``, ``fake``, ``discriminator``, ``gen_loss``,
        ``gradient_penalty_frequency``, ``image_size`` and the optional
        ``noise`` (default 0) and ``cr_weight``.
        """
        super().__init__(opt, env)
        self.real = opt['real']
        self.fake = opt['fake']
        self.discriminator = opt['discriminator']
        self.for_gen = opt['gen_loss']
        self.gp_frequency = opt['gradient_penalty_frequency']
        self.noise = opt['noise'] if 'noise' in opt.keys() else 0
        self.image_size = opt['image_size']
        # Fix: forward() reads self.cr_weight but it was never assigned.
        # NOTE(review): the 0.2 fallback is a reviewer-chosen default - confirm.
        self.cr_weight = opt['cr_weight'] if 'cr_weight' in opt.keys() else .2
        # Most recent consistency loss value; 0 until cutmix has run once.
        self.last_cr_loss = 0

    def forward(self, net, state):
        """Compute the generator or discriminator loss from ``state``.

        ``net`` is not used by this method.
        """
        real_input = state[self.real]
        fake_input = state[self.fake]
        if self.noise != 0:
            fake_input = fake_input + torch.rand_like(fake_input) * self.noise
            real_input = real_input + torch.rand_like(real_input) * self.noise

        D = self.env['discriminators'][self.discriminator]
        fake_dec, fake_enc = D(fake_input)
        # Must be read before D runs again, since each forward overwrites it.
        fake_aug_images = D.aug_images
        if self.for_gen:
            return fake_enc.mean() + F.relu(1 + fake_dec).mean()

        dec_loss_coef = warmup(0, 1., 30000, self.env['step'])
        cutmix_prob = warmup(0, 0.25, 30000, self.env['step'])
        apply_cutmix = random() < cutmix_prob

        real_input.requires_grad_()  # <-- Needed to compute gradients on the input.
        real_dec, real_enc = D(real_input)
        real_aug_images = D.aug_images

        enc_divergence = (F.relu(1 + real_enc) + F.relu(1 - fake_enc)).mean()
        dec_divergence = (F.relu(1 + real_dec) + F.relu(1 - fake_dec)).mean()
        # Fix: disc_loss used to be bound only inside the cutmix branch, so
        # every step without cutmix failed on an unbound local.
        disc_loss = enc_divergence + dec_divergence * dec_loss_coef

        if apply_cutmix:
            # 1 marks pixels taken from the real image, 0 pixels from the fake.
            mask = cutmix(
                torch.ones_like(real_dec),
                torch.zeros_like(real_dec),
                cutmix_coordinates(self.image_size, self.image_size)
            )
            if random() > 0.5:
                mask = 1 - mask

            cutmix_images = mask_src_tgt(real_aug_images, fake_aug_images, mask)
            # Fix: ``self.GAN`` does not exist on this class. The inputs are
            # already augmented, so prefer the wrapped network when the
            # discriminator exposes one as ``.D``.
            # NOTE(review): assumes ``D.D`` is the un-augmented discriminator.
            raw_disc = getattr(D, 'D', D)
            # Fix: the discriminator returns (dec_out, enc_out); the previous
            # unpacking had the two swapped.
            cutmix_dec_out, cutmix_enc_out = raw_disc(cutmix_images)

            # The mixed image is scored as fake by the encoder head and
            # per-pixel (real where mask == 1) by the decoder head.
            cutmix_enc_divergence = F.relu(1 - cutmix_enc_out).mean()
            cutmix_dec_divergence = F.relu(1 + (mask * 2 - 1) * cutmix_dec_out).mean()
            disc_loss = disc_loss + cutmix_enc_divergence + cutmix_dec_divergence

            # Consistency: the decoder output on the mixed image should match
            # the mix of the decoder outputs on the unmixed images.
            cr_cutmix_dec_out = mask_src_tgt(real_dec, fake_dec, mask)
            cr_loss = F.mse_loss(cutmix_dec_out, cr_cutmix_dec_out) * self.cr_weight
            self.last_cr_loss = cr_loss.clone().detach().item()
            disc_loss = disc_loss + cr_loss * dec_loss_coef

        # Apply gradient penalty. TODO: migrate this elsewhere.
        if self.env['step'] % self.gp_frequency == 0:
            from models.archs.stylegan.stylegan2 import gradient_penalty
            # Fix: the undefined name ``real`` was passed here. Both heads are
            # penalised, one call each.
            # NOTE(review): assumes gradient_penalty takes a single output
            # tensor and keeps the graph alive for a second call - confirm.
            gp = gradient_penalty(real_input, real_enc) + gradient_penalty(real_input, real_dec)
            self.metrics.append(("gradient_penalty", gp.clone().detach()))
            disc_loss = disc_loss + gp

        real_input.requires_grad_(requires_grad=False)
        return disc_loss

View File

@ -15,6 +15,9 @@ def create_loss(opt_loss, env):
if 'teco_' in type: if 'teco_' in type:
from models.steps.tecogan_losses import create_teco_loss from models.steps.tecogan_losses import create_teco_loss
return create_teco_loss(opt_loss, env) return create_teco_loss(opt_loss, env)
elif 'stylegan2_' in type:
from models.archs.stylegan import create_stylegan2_loss
return create_stylegan2_loss(opt_loss, env)
elif type == 'pix': elif type == 'pix':
return PixLoss(opt_loss, env) return PixLoss(opt_loss, env)
elif type == 'direct': elif type == 'direct':
@ -37,10 +40,6 @@ def create_loss(opt_loss, env):
return RecurrentLoss(opt_loss, env) return RecurrentLoss(opt_loss, env)
elif type == 'for_element': elif type == 'for_element':
return ForElementLoss(opt_loss, env) return ForElementLoss(opt_loss, env)
elif type == 'stylegan2_divergence':
return StyleGan2DivergenceLoss(opt_loss, env)
elif type == 'stylegan2_pathlen':
return StyleGan2PathLengthLoss(opt_loss, env)
else: else:
raise NotImplementedError raise NotImplementedError
@ -487,68 +486,3 @@ class ForElementLoss(ConfigurableLoss):
def clear_metrics(self): def clear_metrics(self):
self.loss.clear_metrics() self.loss.clear_metrics()
class StyleGan2DivergenceLoss(ConfigurableLoss):
    """Divergence loss evaluated through a discriminator from the environment.

    Depending on ``opt['gen_loss']`` this yields either the generator term
    (mean discriminator output on the fake input) or the discriminator term
    ``mean(relu(1 + D(real)) + relu(1 - D(fake)))``, periodically augmented
    with a gradient penalty on the real input.
    """

    def __init__(self, opt, env):
        """Read the state keys and hyper-parameters out of ``opt``.

        Keys read: ``real``, ``fake``, ``discriminator``, ``gen_loss``,
        ``gradient_penalty_frequency`` and the optional ``noise`` (default 0).
        """
        super().__init__(opt, env)
        self.real, self.fake = opt['real'], opt['fake']
        self.discriminator = opt['discriminator']
        self.for_gen = opt['gen_loss']
        self.gp_frequency = opt['gradient_penalty_frequency']
        if 'noise' in opt:
            self.noise = opt['noise']
        else:
            self.noise = 0

    def forward(self, net, state):
        """Compute the generator or discriminator loss from ``state``.

        ``net`` is not used by this method.
        """
        real_input, fake_input = state[self.real], state[self.fake]
        if self.noise != 0:
            # Fake is perturbed before real so the random draws keep their order.
            fake_input = fake_input + torch.rand_like(fake_input) * self.noise
            real_input = real_input + torch.rand_like(real_input) * self.noise

        disc = self.env['discriminators'][self.discriminator]
        fake_out = disc(fake_input)
        if self.for_gen:
            return fake_out.mean()

        # The gradient penalty differentiates w.r.t. the real input, so it
        # has to track gradients while it passes through the discriminator.
        real_input.requires_grad_()
        real_out = disc(real_input)
        loss = (F.relu(1 + real_out) + F.relu(1 - fake_out)).mean()

        # Apply gradient penalty. TODO: migrate this elsewhere.
        if self.env['step'] % self.gp_frequency == 0:
            from models.archs.stylegan.stylegan2 import gradient_penalty
            penalty = gradient_penalty(real_input, real_out)
            self.metrics.append(("gradient_penalty", penalty.clone().detach()))
            loss = loss + penalty

        # Restore the flag on the (possibly shared) input tensor.
        real_input.requires_grad_(requires_grad=False)
        return loss
class StyleGan2PathLengthLoss(ConfigurableLoss):
    """Path-length penalty: squared deviation of path lengths from their running mean.

    The running mean (``self.pl_mean``) is maintained with an ``EMA(.99)``
    helper and is refreshed on every call, so the penalty keeps tracking the
    generator as it trains.
    """

    def __init__(self, opt, env):
        """Read the state keys ``opt['w_styles']`` and ``opt['gen']``."""
        super().__init__(opt, env)
        self.w_styles = opt['w_styles']
        self.gen = opt['gen']
        # No running mean exists until the first batch has been observed.
        self.pl_mean = None
        from models.archs.stylegan.stylegan2 import EMA
        self.pl_length_ma = EMA(.99)

    def forward(self, net, state):
        """Return the path-length loss for the current batch.

        ``net`` is not used. Returns the integer ``0`` when no running mean
        exists yet or when the computed loss is NaN; otherwise returns the
        mean squared deviation of the path lengths from the running mean.
        """
        from models.archs.stylegan.stylegan2 import calc_pl_lengths, is_empty
        w_styles = state[self.w_styles]
        gen = state[self.gen]
        pl_lengths = calc_pl_lengths(w_styles, gen)
        avg_pl_length = np.mean(pl_lengths.detach().cpu().numpy())

        pl_loss = 0
        if not is_empty(self.pl_mean):
            candidate = ((pl_lengths - self.pl_mean) ** 2).mean()
            if torch.isnan(candidate):
                print("Path length loss returned NaN!")
            else:
                pl_loss = candidate

        # Fix: the previous version returned before reaching this update
        # whenever a valid loss was produced, which froze pl_mean at the value
        # seen on the very first batch. The loss above still uses the
        # pre-update mean. NaN averages are skipped so one bad batch cannot
        # poison the running mean.
        if not np.isnan(avg_pl_length):
            self.pl_mean = self.pl_length_ma.update_average(self.pl_mean, avg_pl_length)
        return pl_loss