Rework stylegan2 divergence losses
Notably: include unet loss
This commit is contained in:
parent
ea94b93a37
commit
99f0cfaab5
|
@ -0,0 +1,14 @@
|
|||
from models.archs.stylegan.stylegan2 import StyleGan2DivergenceLoss, StyleGan2PathLengthLoss
|
||||
from models.archs.stylegan.stylegan2_unet_disc import StyleGan2UnetDivergenceLoss
|
||||
|
||||
|
||||
def create_stylegan2_loss(opt_loss, env):
|
||||
type = opt_loss['type']
|
||||
if type == 'stylegan2_divergence':
|
||||
return StyleGan2DivergenceLoss(opt_loss, env)
|
||||
elif type == 'stylegan2_pathlen':
|
||||
return StyleGan2PathLengthLoss(opt_loss, env)
|
||||
elif type == 'stylegan2_unet_divergence':
|
||||
return StyleGan2UnetDivergenceLoss(opt_loss, env)
|
||||
else:
|
||||
raise NotImplementedError
|
|
@ -13,6 +13,7 @@ from torch import nn
|
|||
from torch.autograd import grad as torch_grad
|
||||
from vector_quantize_pytorch import VectorQuantize
|
||||
|
||||
from models.steps.losses import ConfigurableLoss
|
||||
from utils.util import checkpoint
|
||||
|
||||
try:
|
||||
|
@ -304,6 +305,9 @@ class StyleGan2Augmentor(nn.Module):
|
|||
if detach:
|
||||
images = images.detach()
|
||||
|
||||
# Save away for use elsewhere (e.g. unet loss)
|
||||
self.aug_images = images
|
||||
|
||||
return self.D(images)
|
||||
|
||||
|
||||
|
@ -693,4 +697,69 @@ class StyleGan2Discriminator(nn.Module):
|
|||
def _init_weights(self):
|
||||
for m in self.modules():
|
||||
if type(m) in {nn.Conv2d, nn.Linear}:
|
||||
nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')
|
||||
nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')
|
||||
|
||||
|
||||
class StyleGan2DivergenceLoss(ConfigurableLoss):
|
||||
def __init__(self, opt, env):
|
||||
super().__init__(opt, env)
|
||||
self.real = opt['real']
|
||||
self.fake = opt['fake']
|
||||
self.discriminator = opt['discriminator']
|
||||
self.for_gen = opt['gen_loss']
|
||||
self.gp_frequency = opt['gradient_penalty_frequency']
|
||||
self.noise = opt['noise'] if 'noise' in opt.keys() else 0
|
||||
|
||||
def forward(self, net, state):
|
||||
real_input = state[self.real]
|
||||
fake_input = state[self.fake]
|
||||
if self.noise != 0:
|
||||
fake_input = fake_input + torch.rand_like(fake_input) * self.noise
|
||||
real_input = real_input + torch.rand_like(real_input) * self.noise
|
||||
|
||||
D = self.env['discriminators'][self.discriminator]
|
||||
fake = D(fake_input)
|
||||
if self.for_gen:
|
||||
return fake.mean()
|
||||
else:
|
||||
real_input.requires_grad_() # <-- Needed to compute gradients on the input.
|
||||
real = D(real_input)
|
||||
divergence_loss = (F.relu(1 + real) + F.relu(1 - fake)).mean()
|
||||
|
||||
# Apply gradient penalty. TODO: migrate this elsewhere.
|
||||
if self.env['step'] % self.gp_frequency == 0:
|
||||
from models.archs.stylegan.stylegan2 import gradient_penalty
|
||||
gp = gradient_penalty(real_input, real)
|
||||
self.metrics.append(("gradient_penalty", gp.clone().detach()))
|
||||
divergence_loss = divergence_loss + gp
|
||||
|
||||
real_input.requires_grad_(requires_grad=False)
|
||||
return divergence_loss
|
||||
|
||||
|
||||
class StyleGan2PathLengthLoss(ConfigurableLoss):
|
||||
def __init__(self, opt, env):
|
||||
super().__init__(opt, env)
|
||||
self.w_styles = opt['w_styles']
|
||||
self.gen = opt['gen']
|
||||
self.pl_mean = None
|
||||
from models.archs.stylegan.stylegan2 import EMA
|
||||
self.pl_length_ma = EMA(.99)
|
||||
|
||||
def forward(self, net, state):
|
||||
w_styles = state[self.w_styles]
|
||||
gen = state[self.gen]
|
||||
from models.archs.stylegan.stylegan2 import calc_pl_lengths
|
||||
pl_lengths = calc_pl_lengths(w_styles, gen)
|
||||
avg_pl_length = np.mean(pl_lengths.detach().cpu().numpy())
|
||||
|
||||
from models.archs.stylegan.stylegan2 import is_empty
|
||||
if not is_empty(self.pl_mean):
|
||||
pl_loss = ((pl_lengths - self.pl_mean) ** 2).mean()
|
||||
if not torch.isnan(pl_loss):
|
||||
return pl_loss
|
||||
else:
|
||||
print("Path length loss returned NaN!")
|
||||
|
||||
self.pl_mean = self.pl_length_ma.update_average(self.pl_mean, avg_pl_length)
|
||||
return 0
|
||||
|
|
|
@ -1,10 +1,14 @@
|
|||
from functools import partial
|
||||
from math import log2
|
||||
from random import random
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from models.archs.stylegan.stylegan2 import attn_and_ff
|
||||
from models.steps.losses import ConfigurableLoss
|
||||
|
||||
|
||||
def leaky_relu(p=0.2):
|
||||
|
@ -132,4 +136,105 @@ class StyleGan2UnetDiscriminator(nn.Module):
|
|||
x = up_block(x, res)
|
||||
|
||||
dec_out = self.conv_out(x)
|
||||
return dec_out
|
||||
return dec_out, enc_out
|
||||
|
||||
|
||||
def warmup(start, end, max_steps, current_step):
|
||||
if current_step > max_steps:
|
||||
return end
|
||||
return (end - start) * (current_step / max_steps) + start
|
||||
|
||||
|
||||
def mask_src_tgt(source, target, mask):
|
||||
return source * mask + (1 - mask) * target
|
||||
|
||||
|
||||
def cutmix(source, target, coors, alpha = 1.):
|
||||
source, target = map(torch.clone, (source, target))
|
||||
((y0, y1), (x0, x1)), _ = coors
|
||||
source[:, :, y0:y1, x0:x1] = target[:, :, y0:y1, x0:x1]
|
||||
return source
|
||||
|
||||
|
||||
def cutmix_coordinates(height, width, alpha = 1.):
|
||||
lam = np.random.beta(alpha, alpha)
|
||||
|
||||
cx = np.random.uniform(0, width)
|
||||
cy = np.random.uniform(0, height)
|
||||
w = width * np.sqrt(1 - lam)
|
||||
h = height * np.sqrt(1 - lam)
|
||||
x0 = int(np.round(max(cx - w / 2, 0)))
|
||||
x1 = int(np.round(min(cx + w / 2, width)))
|
||||
y0 = int(np.round(max(cy - h / 2, 0)))
|
||||
y1 = int(np.round(min(cy + h / 2, height)))
|
||||
|
||||
return ((y0, y1), (x0, x1)), lam
|
||||
|
||||
|
||||
class StyleGan2UnetDivergenceLoss(ConfigurableLoss):
|
||||
def __init__(self, opt, env):
|
||||
super().__init__(opt, env)
|
||||
self.real = opt['real']
|
||||
self.fake = opt['fake']
|
||||
self.discriminator = opt['discriminator']
|
||||
self.for_gen = opt['gen_loss']
|
||||
self.gp_frequency = opt['gradient_penalty_frequency']
|
||||
self.noise = opt['noise'] if 'noise' in opt.keys() else 0
|
||||
self.image_size = opt['image_size']
|
||||
|
||||
def forward(self, net, state):
|
||||
real_input = state[self.real]
|
||||
fake_input = state[self.fake]
|
||||
if self.noise != 0:
|
||||
fake_input = fake_input + torch.rand_like(fake_input) * self.noise
|
||||
real_input = real_input + torch.rand_like(real_input) * self.noise
|
||||
|
||||
D = self.env['discriminators'][self.discriminator]
|
||||
fake_dec, fake_enc = D(fake_input)
|
||||
fake_aug_images = D.aug_images
|
||||
if self.for_gen:
|
||||
return fake_enc.mean() + F.relu(1 + fake_dec).mean()
|
||||
else:
|
||||
dec_loss_coef = warmup(0, 1., 30000, self.env['step'])
|
||||
cutmix_prob = warmup(0, 0.25, 30000, self.env['step'])
|
||||
apply_cutmix = random() < cutmix_prob
|
||||
|
||||
real_input.requires_grad_() # <-- Needed to compute gradients on the input.
|
||||
real_dec, real_enc = D(real_input)
|
||||
real_aug_images = D.aug_images
|
||||
enc_divergence = (F.relu(1 + real_enc) + F.relu(1 - fake_enc)).mean()
|
||||
dec_divergence = (F.relu(1 + real_dec) + F.relu(1 - fake_dec)).mean()
|
||||
divergence_loss = enc_divergence + dec_divergence * dec_loss_coef
|
||||
|
||||
if apply_cutmix:
|
||||
mask = cutmix(
|
||||
torch.ones_like(real_dec),
|
||||
torch.zeros_like(real_dec),
|
||||
cutmix_coordinates(self.image_size, self.image_size)
|
||||
)
|
||||
|
||||
if random() > 0.5:
|
||||
mask = 1 - mask
|
||||
|
||||
cutmix_images = mask_src_tgt(real_aug_images, fake_aug_images, mask)
|
||||
cutmix_enc_out, cutmix_dec_out = self.GAN.D(cutmix_images)
|
||||
|
||||
cutmix_enc_divergence = F.relu(1 - cutmix_enc_out).mean()
|
||||
cutmix_dec_divergence = F.relu(1 + (mask * 2 - 1) * cutmix_dec_out).mean()
|
||||
disc_loss = divergence_loss + cutmix_enc_divergence + cutmix_dec_divergence
|
||||
|
||||
cr_cutmix_dec_out = mask_src_tgt(real_dec, fake_dec, mask)
|
||||
cr_loss = F.mse_loss(cutmix_dec_out, cr_cutmix_dec_out) * self.cr_weight
|
||||
self.last_cr_loss = cr_loss.clone().detach().item()
|
||||
|
||||
disc_loss = disc_loss + cr_loss * dec_loss_coef
|
||||
|
||||
# Apply gradient penalty. TODO: migrate this elsewhere.
|
||||
if self.env['step'] % self.gp_frequency == 0:
|
||||
from models.archs.stylegan.stylegan2 import gradient_penalty
|
||||
gp = gradient_penalty(real_input, real)
|
||||
self.metrics.append(("gradient_penalty", gp.clone().detach()))
|
||||
disc_loss = disc_loss + gp
|
||||
|
||||
real_input.requires_grad_(requires_grad=False)
|
||||
return disc_loss
|
|
@ -15,6 +15,9 @@ def create_loss(opt_loss, env):
|
|||
if 'teco_' in type:
|
||||
from models.steps.tecogan_losses import create_teco_loss
|
||||
return create_teco_loss(opt_loss, env)
|
||||
elif 'stylegan2_' in type:
|
||||
from models.archs.stylegan import create_stylegan2_loss
|
||||
return create_stylegan2_loss(opt_loss, env)
|
||||
elif type == 'pix':
|
||||
return PixLoss(opt_loss, env)
|
||||
elif type == 'direct':
|
||||
|
@ -37,10 +40,6 @@ def create_loss(opt_loss, env):
|
|||
return RecurrentLoss(opt_loss, env)
|
||||
elif type == 'for_element':
|
||||
return ForElementLoss(opt_loss, env)
|
||||
elif type == 'stylegan2_divergence':
|
||||
return StyleGan2DivergenceLoss(opt_loss, env)
|
||||
elif type == 'stylegan2_pathlen':
|
||||
return StyleGan2PathLengthLoss(opt_loss, env)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
@ -487,68 +486,3 @@ class ForElementLoss(ConfigurableLoss):
|
|||
|
||||
def clear_metrics(self):
|
||||
self.loss.clear_metrics()
|
||||
|
||||
|
||||
class StyleGan2DivergenceLoss(ConfigurableLoss):
|
||||
def __init__(self, opt, env):
|
||||
super().__init__(opt, env)
|
||||
self.real = opt['real']
|
||||
self.fake = opt['fake']
|
||||
self.discriminator = opt['discriminator']
|
||||
self.for_gen = opt['gen_loss']
|
||||
self.gp_frequency = opt['gradient_penalty_frequency']
|
||||
self.noise = opt['noise'] if 'noise' in opt.keys() else 0
|
||||
|
||||
def forward(self, net, state):
|
||||
real_input = state[self.real]
|
||||
fake_input = state[self.fake]
|
||||
if self.noise != 0:
|
||||
fake_input = fake_input + torch.rand_like(fake_input) * self.noise
|
||||
real_input = real_input + torch.rand_like(real_input) * self.noise
|
||||
|
||||
D = self.env['discriminators'][self.discriminator]
|
||||
fake = D(fake_input)
|
||||
if self.for_gen:
|
||||
return fake.mean()
|
||||
else:
|
||||
real_input.requires_grad_() # <-- Needed to compute gradients on the input.
|
||||
real = D(real_input)
|
||||
divergence_loss = (F.relu(1 + real) + F.relu(1 - fake)).mean()
|
||||
|
||||
# Apply gradient penalty. TODO: migrate this elsewhere.
|
||||
if self.env['step'] % self.gp_frequency == 0:
|
||||
from models.archs.stylegan.stylegan2 import gradient_penalty
|
||||
gp = gradient_penalty(real_input, real)
|
||||
self.metrics.append(("gradient_penalty", gp.clone().detach()))
|
||||
divergence_loss = divergence_loss + gp
|
||||
|
||||
real_input.requires_grad_(requires_grad=False)
|
||||
return divergence_loss
|
||||
|
||||
|
||||
class StyleGan2PathLengthLoss(ConfigurableLoss):
|
||||
def __init__(self, opt, env):
|
||||
super().__init__(opt, env)
|
||||
self.w_styles = opt['w_styles']
|
||||
self.gen = opt['gen']
|
||||
self.pl_mean = None
|
||||
from models.archs.stylegan.stylegan2 import EMA
|
||||
self.pl_length_ma = EMA(.99)
|
||||
|
||||
def forward(self, net, state):
|
||||
w_styles = state[self.w_styles]
|
||||
gen = state[self.gen]
|
||||
from models.archs.stylegan.stylegan2 import calc_pl_lengths
|
||||
pl_lengths = calc_pl_lengths(w_styles, gen)
|
||||
avg_pl_length = np.mean(pl_lengths.detach().cpu().numpy())
|
||||
|
||||
from models.archs.stylegan.stylegan2 import is_empty
|
||||
if not is_empty(self.pl_mean):
|
||||
pl_loss = ((pl_lengths - self.pl_mean) ** 2).mean()
|
||||
if not torch.isnan(pl_loss):
|
||||
return pl_loss
|
||||
else:
|
||||
print("Path length loss returned NaN!")
|
||||
|
||||
self.pl_mean = self.pl_length_ma.update_average(self.pl_mean, avg_pl_length)
|
||||
return 0
|
||||
|
|
Loading…
Reference in New Issue
Block a user