forked from mrq/DL-Art-School
Rework stylegan2 divergence losses
Notably: include unet loss
parent ea94b93a37
commit 99f0cfaab5
models/archs/stylegan/__init__.py (new file):

@@ -0,0 +1,14 @@
+from models.archs.stylegan.stylegan2 import StyleGan2DivergenceLoss, StyleGan2PathLengthLoss
+from models.archs.stylegan.stylegan2_unet_disc import StyleGan2UnetDivergenceLoss
+
+
+def create_stylegan2_loss(opt_loss, env):
+    type = opt_loss['type']
+    if type == 'stylegan2_divergence':
+        return StyleGan2DivergenceLoss(opt_loss, env)
+    elif type == 'stylegan2_pathlen':
+        return StyleGan2PathLengthLoss(opt_loss, env)
+    elif type == 'stylegan2_unet_divergence':
+        return StyleGan2UnetDivergenceLoss(opt_loss, env)
+    else:
+        raise NotImplementedError
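For context, a minimal sketch of how this factory might be wired up, not part of the commit: the option keys mirror what the losses read in their __init__ methods below, the state and discriminator names ('hq', 'gen', 'stylegan2_d') are hypothetical placeholders, and the env layout ('discriminators', 'step') is inferred from the self.env accesses in the loss implementations.

    # Hypothetical wiring sketch; only the option keys come from the loss code.
    from models.archs.stylegan import create_stylegan2_loss

    opt_loss = {
        'type': 'stylegan2_unet_divergence',  # dispatched on the 'stylegan2_' prefix
        'real': 'hq',                         # state key holding real images (placeholder)
        'fake': 'gen',                        # state key holding generated images (placeholder)
        'discriminator': 'stylegan2_d',       # key into env['discriminators'] (placeholder)
        'gen_loss': False,                    # True when used as the generator's loss
        'gradient_penalty_frequency': 4,
        'noise': 0.0,
        'image_size': 128,
    }
    env = {'discriminators': {'stylegan2_d': disc}, 'step': 0}  # disc: augmentor-wrapped unet D
    loss = create_stylegan2_loss(opt_loss, env)
    value = loss.forward(net=None, state={'hq': real_batch, 'gen': fake_batch})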
models/archs/stylegan/stylegan2.py:

@@ -13,6 +13,7 @@ from torch import nn
 from torch.autograd import grad as torch_grad
 from vector_quantize_pytorch import VectorQuantize
 
+from models.steps.losses import ConfigurableLoss
 from utils.util import checkpoint
 
 try:
@@ -304,6 +305,9 @@ class StyleGan2Augmentor(nn.Module):
         if detach:
             images = images.detach()
 
+        # Save away for use elsewhere (e.g. unet loss)
+        self.aug_images = images
+
         return self.D(images)
 
 
@@ -694,3 +698,68 @@ class StyleGan2Discriminator(nn.Module):
         for m in self.modules():
             if type(m) in {nn.Conv2d, nn.Linear}:
                 nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')
+
+
+class StyleGan2DivergenceLoss(ConfigurableLoss):
+    def __init__(self, opt, env):
+        super().__init__(opt, env)
+        self.real = opt['real']
+        self.fake = opt['fake']
+        self.discriminator = opt['discriminator']
+        self.for_gen = opt['gen_loss']
+        self.gp_frequency = opt['gradient_penalty_frequency']
+        self.noise = opt['noise'] if 'noise' in opt.keys() else 0
+
+    def forward(self, net, state):
+        real_input = state[self.real]
+        fake_input = state[self.fake]
+        if self.noise != 0:
+            fake_input = fake_input + torch.rand_like(fake_input) * self.noise
+            real_input = real_input + torch.rand_like(real_input) * self.noise
+
+        D = self.env['discriminators'][self.discriminator]
+        fake = D(fake_input)
+        if self.for_gen:
+            return fake.mean()
+        else:
+            real_input.requires_grad_()  # <-- Needed to compute gradients on the input.
+            real = D(real_input)
+            divergence_loss = (F.relu(1 + real) + F.relu(1 - fake)).mean()
+
+            # Apply gradient penalty. TODO: migrate this elsewhere.
+            if self.env['step'] % self.gp_frequency == 0:
+                from models.archs.stylegan.stylegan2 import gradient_penalty
+                gp = gradient_penalty(real_input, real)
+                self.metrics.append(("gradient_penalty", gp.clone().detach()))
+                divergence_loss = divergence_loss + gp
+
+            real_input.requires_grad_(requires_grad=False)
+            return divergence_loss
+
+
+class StyleGan2PathLengthLoss(ConfigurableLoss):
+    def __init__(self, opt, env):
+        super().__init__(opt, env)
+        self.w_styles = opt['w_styles']
+        self.gen = opt['gen']
+        self.pl_mean = None
+        from models.archs.stylegan.stylegan2 import EMA
+        self.pl_length_ma = EMA(.99)
+
+    def forward(self, net, state):
+        w_styles = state[self.w_styles]
+        gen = state[self.gen]
+        from models.archs.stylegan.stylegan2 import calc_pl_lengths
+        pl_lengths = calc_pl_lengths(w_styles, gen)
+        avg_pl_length = np.mean(pl_lengths.detach().cpu().numpy())
+
+        from models.archs.stylegan.stylegan2 import is_empty
+        if not is_empty(self.pl_mean):
+            pl_loss = ((pl_lengths - self.pl_mean) ** 2).mean()
+            if not torch.isnan(pl_loss):
+                return pl_loss
+            else:
+                print("Path length loss returned NaN!")
+
+        self.pl_mean = self.pl_length_ma.update_average(self.pl_mean, avg_pl_length)
+        return 0
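The gradient_penalty, EMA, calc_pl_lengths and is_empty helpers referenced above live in this module but their definitions are outside the diff. For reference, a sketch of a gradient penalty consistent with the gradient_penalty(real_input, real) call above, assuming it matches the standard lucidrains-style stylegan2-pytorch implementation this file derives from:

    import torch
    from torch.autograd import grad as torch_grad

    def gradient_penalty(images, output, weight=10):
        # Norm of d(output)/d(images), pushed toward 1 and scaled by `weight`.
        batch_size = images.shape[0]
        gradients = torch_grad(outputs=output, inputs=images,
                               grad_outputs=torch.ones(output.size(), device=images.device),
                               create_graph=True, retain_graph=True, only_inputs=True)[0]
        gradients = gradients.reshape(batch_size, -1)
        return weight * ((gradients.norm(2, dim=1) - 1) ** 2).mean()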
models/archs/stylegan/stylegan2_unet_disc.py:

@@ -1,10 +1,14 @@
 from functools import partial
 from math import log2
+from random import random
+import numpy as np
 
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 
 from models.archs.stylegan.stylegan2 import attn_and_ff
+from models.steps.losses import ConfigurableLoss
 
 
 def leaky_relu(p=0.2):
@@ -132,4 +136,107 @@ class StyleGan2UnetDiscriminator(nn.Module):
             x = up_block(x, res)
 
         dec_out = self.conv_out(x)
-        return dec_out
+        return dec_out, enc_out
+
+
+def warmup(start, end, max_steps, current_step):
+    if current_step > max_steps:
+        return end
+    return (end - start) * (current_step / max_steps) + start
+
+
+def mask_src_tgt(source, target, mask):
+    return source * mask + (1 - mask) * target
+
+
+def cutmix(source, target, coors, alpha=1.):
+    source, target = map(torch.clone, (source, target))
+    ((y0, y1), (x0, x1)), _ = coors
+    source[:, :, y0:y1, x0:x1] = target[:, :, y0:y1, x0:x1]
+    return source
+
+
+def cutmix_coordinates(height, width, alpha=1.):
+    lam = np.random.beta(alpha, alpha)
+
+    cx = np.random.uniform(0, width)
+    cy = np.random.uniform(0, height)
+    w = width * np.sqrt(1 - lam)
+    h = height * np.sqrt(1 - lam)
+    x0 = int(np.round(max(cx - w / 2, 0)))
+    x1 = int(np.round(min(cx + w / 2, width)))
+    y0 = int(np.round(max(cy - h / 2, 0)))
+    y1 = int(np.round(min(cy + h / 2, height)))
+
+    return ((y0, y1), (x0, x1)), lam
+
+
+class StyleGan2UnetDivergenceLoss(ConfigurableLoss):
+    def __init__(self, opt, env):
+        super().__init__(opt, env)
+        self.real = opt['real']
+        self.fake = opt['fake']
+        self.discriminator = opt['discriminator']
+        self.for_gen = opt['gen_loss']
+        self.gp_frequency = opt['gradient_penalty_frequency']
+        self.noise = opt['noise'] if 'noise' in opt.keys() else 0
+        self.image_size = opt['image_size']
+        self.cr_weight = opt['cr_weight'] if 'cr_weight' in opt.keys() else .2  # cutmix consistency regularization weight
+
+    def forward(self, net, state):
+        real_input = state[self.real]
+        fake_input = state[self.fake]
+        if self.noise != 0:
+            fake_input = fake_input + torch.rand_like(fake_input) * self.noise
+            real_input = real_input + torch.rand_like(real_input) * self.noise
+
+        D = self.env['discriminators'][self.discriminator]
+        fake_dec, fake_enc = D(fake_input)
+        fake_aug_images = D.aug_images
+        if self.for_gen:
+            return fake_enc.mean() + F.relu(1 + fake_dec).mean()
+        else:
+            dec_loss_coef = warmup(0, 1., 30000, self.env['step'])
+            cutmix_prob = warmup(0, 0.25, 30000, self.env['step'])
+            apply_cutmix = random() < cutmix_prob
+
+            real_input.requires_grad_()  # <-- Needed to compute gradients on the input.
+            real_dec, real_enc = D(real_input)
+            real_aug_images = D.aug_images
+            enc_divergence = (F.relu(1 + real_enc) + F.relu(1 - fake_enc)).mean()
+            dec_divergence = (F.relu(1 + real_dec) + F.relu(1 - fake_dec)).mean()
+            divergence_loss = enc_divergence + dec_divergence * dec_loss_coef
+            disc_loss = divergence_loss
+
+            if apply_cutmix:
+                mask = cutmix(
+                    torch.ones_like(real_dec),
+                    torch.zeros_like(real_dec),
+                    cutmix_coordinates(self.image_size, self.image_size)
+                )
+
+                if random() > 0.5:
+                    mask = 1 - mask
+
+                cutmix_images = mask_src_tgt(real_aug_images, fake_aug_images, mask)
+                cutmix_dec_out, cutmix_enc_out = D(cutmix_images)
+
+                cutmix_enc_divergence = F.relu(1 - cutmix_enc_out).mean()
+                cutmix_dec_divergence = F.relu(1 + (mask * 2 - 1) * cutmix_dec_out).mean()
+                disc_loss = disc_loss + cutmix_enc_divergence + cutmix_dec_divergence
+
+                cr_cutmix_dec_out = mask_src_tgt(real_dec, fake_dec, mask)
+                cr_loss = F.mse_loss(cutmix_dec_out, cr_cutmix_dec_out) * self.cr_weight
+                self.last_cr_loss = cr_loss.clone().detach().item()
+
+                disc_loss = disc_loss + cr_loss * dec_loss_coef
+
+            # Apply gradient penalty. TODO: migrate this elsewhere.
+            if self.env['step'] % self.gp_frequency == 0:
+                from models.archs.stylegan.stylegan2 import gradient_penalty
+                gp = gradient_penalty(real_input, real_enc)
+                self.metrics.append(("gradient_penalty", gp.clone().detach()))
+                disc_loss = disc_loss + gp
+
+            real_input.requires_grad_(requires_grad=False)
+            return disc_loss
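A quick illustration of the cutmix mask semantics behind the decoder hinge above, assuming the helpers are importable from this module after the commit (illustrative values only):

    import torch
    from models.archs.stylegan.stylegan2_unet_disc import cutmix, warmup

    # warmup() ramps a coefficient linearly over max_steps, then clamps:
    assert warmup(0, 1., 30000, 15000) == 0.5
    assert warmup(0, 1., 30000, 60000) == 1.0

    # Pasting zeros into ones marks the cutmix rectangle with 0s:
    coors = ((2, 5), (1, 4)), 0.7  # ((y0, y1), (x0, x1)), lam
    mask = cutmix(torch.ones(1, 1, 8, 8), torch.zeros(1, 1, 8, 8), coors)
    # mask == 1 where the mixed image is real, 0 where it is fake. mask * 2 - 1
    # maps that to +1/-1, so F.relu(1 + (mask * 2 - 1) * cutmix_dec_out) applies
    # the real-image hinge on real pixels and the fake-image hinge on fake pixels.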
models/steps/losses.py:

@@ -15,6 +15,9 @@ def create_loss(opt_loss, env):
     if 'teco_' in type:
         from models.steps.tecogan_losses import create_teco_loss
        return create_teco_loss(opt_loss, env)
+    elif 'stylegan2_' in type:
+        from models.archs.stylegan import create_stylegan2_loss
+        return create_stylegan2_loss(opt_loss, env)
     elif type == 'pix':
         return PixLoss(opt_loss, env)
     elif type == 'direct':
@@ -37,10 +40,6 @@ def create_loss(opt_loss, env):
         return RecurrentLoss(opt_loss, env)
     elif type == 'for_element':
         return ForElementLoss(opt_loss, env)
-    elif type == 'stylegan2_divergence':
-        return StyleGan2DivergenceLoss(opt_loss, env)
-    elif type == 'stylegan2_pathlen':
-        return StyleGan2PathLengthLoss(opt_loss, env)
     else:
         raise NotImplementedError
 
@@ -487,68 +486,3 @@ class ForElementLoss(ConfigurableLoss):
 
     def clear_metrics(self):
         self.loss.clear_metrics()
-
-
-class StyleGan2DivergenceLoss(ConfigurableLoss):
-    def __init__(self, opt, env):
-        super().__init__(opt, env)
-        self.real = opt['real']
-        self.fake = opt['fake']
-        self.discriminator = opt['discriminator']
-        self.for_gen = opt['gen_loss']
-        self.gp_frequency = opt['gradient_penalty_frequency']
-        self.noise = opt['noise'] if 'noise' in opt.keys() else 0
-
-    def forward(self, net, state):
-        real_input = state[self.real]
-        fake_input = state[self.fake]
-        if self.noise != 0:
-            fake_input = fake_input + torch.rand_like(fake_input) * self.noise
-            real_input = real_input + torch.rand_like(real_input) * self.noise
-
-        D = self.env['discriminators'][self.discriminator]
-        fake = D(fake_input)
-        if self.for_gen:
-            return fake.mean()
-        else:
-            real_input.requires_grad_()  # <-- Needed to compute gradients on the input.
-            real = D(real_input)
-            divergence_loss = (F.relu(1 + real) + F.relu(1 - fake)).mean()
-
-            # Apply gradient penalty. TODO: migrate this elsewhere.
-            if self.env['step'] % self.gp_frequency == 0:
-                from models.archs.stylegan.stylegan2 import gradient_penalty
-                gp = gradient_penalty(real_input, real)
-                self.metrics.append(("gradient_penalty", gp.clone().detach()))
-                divergence_loss = divergence_loss + gp
-
-            real_input.requires_grad_(requires_grad=False)
-            return divergence_loss
-
-
-class StyleGan2PathLengthLoss(ConfigurableLoss):
-    def __init__(self, opt, env):
-        super().__init__(opt, env)
-        self.w_styles = opt['w_styles']
-        self.gen = opt['gen']
-        self.pl_mean = None
-        from models.archs.stylegan.stylegan2 import EMA
-        self.pl_length_ma = EMA(.99)
-
-    def forward(self, net, state):
-        w_styles = state[self.w_styles]
-        gen = state[self.gen]
-        from models.archs.stylegan.stylegan2 import calc_pl_lengths
-        pl_lengths = calc_pl_lengths(w_styles, gen)
-        avg_pl_length = np.mean(pl_lengths.detach().cpu().numpy())
-
-        from models.archs.stylegan.stylegan2 import is_empty
-        if not is_empty(self.pl_mean):
-            pl_loss = ((pl_lengths - self.pl_mean) ** 2).mean()
-            if not torch.isnan(pl_loss):
-                return pl_loss
-            else:
-                print("Path length loss returned NaN!")
-
-        self.pl_mean = self.pl_length_ma.update_average(self.pl_mean, avg_pl_length)
-        return 0