From 99f0cfaab51a92ab7269febab0573750433b2277 Mon Sep 17 00:00:00 2001 From: James Betker Date: Sun, 15 Nov 2020 11:26:44 -0700 Subject: [PATCH] Rework stylegan2 divergence losses Notably: include unet loss --- codes/models/archs/stylegan/__init__.py | 14 +++ codes/models/archs/stylegan/stylegan2.py | 71 +++++++++++- .../archs/stylegan/stylegan2_unet_disc.py | 107 +++++++++++++++++- codes/models/steps/losses.py | 72 +----------- 4 files changed, 193 insertions(+), 71 deletions(-) diff --git a/codes/models/archs/stylegan/__init__.py b/codes/models/archs/stylegan/__init__.py index e69de29b..4ab78ddf 100644 --- a/codes/models/archs/stylegan/__init__.py +++ b/codes/models/archs/stylegan/__init__.py @@ -0,0 +1,14 @@ +from models.archs.stylegan.stylegan2 import StyleGan2DivergenceLoss, StyleGan2PathLengthLoss +from models.archs.stylegan.stylegan2_unet_disc import StyleGan2UnetDivergenceLoss + + +def create_stylegan2_loss(opt_loss, env): + type = opt_loss['type'] + if type == 'stylegan2_divergence': + return StyleGan2DivergenceLoss(opt_loss, env) + elif type == 'stylegan2_pathlen': + return StyleGan2PathLengthLoss(opt_loss, env) + elif type == 'stylegan2_unet_divergence': + return StyleGan2UnetDivergenceLoss(opt_loss, env) + else: + raise NotImplementedError \ No newline at end of file diff --git a/codes/models/archs/stylegan/stylegan2.py b/codes/models/archs/stylegan/stylegan2.py index a676ef95..adba029f 100644 --- a/codes/models/archs/stylegan/stylegan2.py +++ b/codes/models/archs/stylegan/stylegan2.py @@ -13,6 +13,7 @@ from torch import nn from torch.autograd import grad as torch_grad from vector_quantize_pytorch import VectorQuantize +from models.steps.losses import ConfigurableLoss from utils.util import checkpoint try: @@ -304,6 +305,9 @@ class StyleGan2Augmentor(nn.Module): if detach: images = images.detach() + # Save away for use elsewhere (e.g. unet loss) + self.aug_images = images + return self.D(images) @@ -693,4 +697,69 @@ class StyleGan2Discriminator(nn.Module): def _init_weights(self): for m in self.modules(): if type(m) in {nn.Conv2d, nn.Linear}: - nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu') \ No newline at end of file + nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu') + + +class StyleGan2DivergenceLoss(ConfigurableLoss): + def __init__(self, opt, env): + super().__init__(opt, env) + self.real = opt['real'] + self.fake = opt['fake'] + self.discriminator = opt['discriminator'] + self.for_gen = opt['gen_loss'] + self.gp_frequency = opt['gradient_penalty_frequency'] + self.noise = opt['noise'] if 'noise' in opt.keys() else 0 + + def forward(self, net, state): + real_input = state[self.real] + fake_input = state[self.fake] + if self.noise != 0: + fake_input = fake_input + torch.rand_like(fake_input) * self.noise + real_input = real_input + torch.rand_like(real_input) * self.noise + + D = self.env['discriminators'][self.discriminator] + fake = D(fake_input) + if self.for_gen: + return fake.mean() + else: + real_input.requires_grad_() # <-- Needed to compute gradients on the input. + real = D(real_input) + divergence_loss = (F.relu(1 + real) + F.relu(1 - fake)).mean() + + # Apply gradient penalty. TODO: migrate this elsewhere. + if self.env['step'] % self.gp_frequency == 0: + from models.archs.stylegan.stylegan2 import gradient_penalty + gp = gradient_penalty(real_input, real) + self.metrics.append(("gradient_penalty", gp.clone().detach())) + divergence_loss = divergence_loss + gp + + real_input.requires_grad_(requires_grad=False) + return divergence_loss + + +class StyleGan2PathLengthLoss(ConfigurableLoss): + def __init__(self, opt, env): + super().__init__(opt, env) + self.w_styles = opt['w_styles'] + self.gen = opt['gen'] + self.pl_mean = None + from models.archs.stylegan.stylegan2 import EMA + self.pl_length_ma = EMA(.99) + + def forward(self, net, state): + w_styles = state[self.w_styles] + gen = state[self.gen] + from models.archs.stylegan.stylegan2 import calc_pl_lengths + pl_lengths = calc_pl_lengths(w_styles, gen) + avg_pl_length = np.mean(pl_lengths.detach().cpu().numpy()) + + from models.archs.stylegan.stylegan2 import is_empty + if not is_empty(self.pl_mean): + pl_loss = ((pl_lengths - self.pl_mean) ** 2).mean() + if not torch.isnan(pl_loss): + return pl_loss + else: + print("Path length loss returned NaN!") + + self.pl_mean = self.pl_length_ma.update_average(self.pl_mean, avg_pl_length) + return 0 diff --git a/codes/models/archs/stylegan/stylegan2_unet_disc.py b/codes/models/archs/stylegan/stylegan2_unet_disc.py index f506b83e..a4c791d0 100644 --- a/codes/models/archs/stylegan/stylegan2_unet_disc.py +++ b/codes/models/archs/stylegan/stylegan2_unet_disc.py @@ -1,10 +1,14 @@ from functools import partial from math import log2 +from random import random +import numpy as np import torch import torch.nn as nn +import torch.nn.functional as F from models.archs.stylegan.stylegan2 import attn_and_ff +from models.steps.losses import ConfigurableLoss def leaky_relu(p=0.2): @@ -132,4 +136,105 @@ class StyleGan2UnetDiscriminator(nn.Module): x = up_block(x, res) dec_out = self.conv_out(x) - return dec_out + return dec_out, enc_out + + +def warmup(start, end, max_steps, current_step): + if current_step > max_steps: + return end + return (end - start) * (current_step / max_steps) + start + + +def mask_src_tgt(source, target, mask): + return source * mask + (1 - mask) * target + + +def cutmix(source, target, coors, alpha = 1.): + source, target = map(torch.clone, (source, target)) + ((y0, y1), (x0, x1)), _ = coors + source[:, :, y0:y1, x0:x1] = target[:, :, y0:y1, x0:x1] + return source + + +def cutmix_coordinates(height, width, alpha = 1.): + lam = np.random.beta(alpha, alpha) + + cx = np.random.uniform(0, width) + cy = np.random.uniform(0, height) + w = width * np.sqrt(1 - lam) + h = height * np.sqrt(1 - lam) + x0 = int(np.round(max(cx - w / 2, 0))) + x1 = int(np.round(min(cx + w / 2, width))) + y0 = int(np.round(max(cy - h / 2, 0))) + y1 = int(np.round(min(cy + h / 2, height))) + + return ((y0, y1), (x0, x1)), lam + + +class StyleGan2UnetDivergenceLoss(ConfigurableLoss): + def __init__(self, opt, env): + super().__init__(opt, env) + self.real = opt['real'] + self.fake = opt['fake'] + self.discriminator = opt['discriminator'] + self.for_gen = opt['gen_loss'] + self.gp_frequency = opt['gradient_penalty_frequency'] + self.noise = opt['noise'] if 'noise' in opt.keys() else 0 + self.image_size = opt['image_size'] + + def forward(self, net, state): + real_input = state[self.real] + fake_input = state[self.fake] + if self.noise != 0: + fake_input = fake_input + torch.rand_like(fake_input) * self.noise + real_input = real_input + torch.rand_like(real_input) * self.noise + + D = self.env['discriminators'][self.discriminator] + fake_dec, fake_enc = D(fake_input) + fake_aug_images = D.aug_images + if self.for_gen: + return fake_enc.mean() + F.relu(1 + fake_dec).mean() + else: + dec_loss_coef = warmup(0, 1., 30000, self.env['step']) + cutmix_prob = warmup(0, 0.25, 30000, self.env['step']) + apply_cutmix = random() < cutmix_prob + + real_input.requires_grad_() # <-- Needed to compute gradients on the input. + real_dec, real_enc = D(real_input) + real_aug_images = D.aug_images + enc_divergence = (F.relu(1 + real_enc) + F.relu(1 - fake_enc)).mean() + dec_divergence = (F.relu(1 + real_dec) + F.relu(1 - fake_dec)).mean() + divergence_loss = enc_divergence + dec_divergence * dec_loss_coef + + if apply_cutmix: + mask = cutmix( + torch.ones_like(real_dec), + torch.zeros_like(real_dec), + cutmix_coordinates(self.image_size, self.image_size) + ) + + if random() > 0.5: + mask = 1 - mask + + cutmix_images = mask_src_tgt(real_aug_images, fake_aug_images, mask) + cutmix_enc_out, cutmix_dec_out = self.GAN.D(cutmix_images) + + cutmix_enc_divergence = F.relu(1 - cutmix_enc_out).mean() + cutmix_dec_divergence = F.relu(1 + (mask * 2 - 1) * cutmix_dec_out).mean() + disc_loss = divergence_loss + cutmix_enc_divergence + cutmix_dec_divergence + + cr_cutmix_dec_out = mask_src_tgt(real_dec, fake_dec, mask) + cr_loss = F.mse_loss(cutmix_dec_out, cr_cutmix_dec_out) * self.cr_weight + self.last_cr_loss = cr_loss.clone().detach().item() + + disc_loss = disc_loss + cr_loss * dec_loss_coef + + # Apply gradient penalty. TODO: migrate this elsewhere. + if self.env['step'] % self.gp_frequency == 0: + from models.archs.stylegan.stylegan2 import gradient_penalty + gp = gradient_penalty(real_input, real) + self.metrics.append(("gradient_penalty", gp.clone().detach())) + disc_loss = disc_loss + gp + + real_input.requires_grad_(requires_grad=False) + return disc_loss \ No newline at end of file diff --git a/codes/models/steps/losses.py b/codes/models/steps/losses.py index 639672cf..ce08eb26 100644 --- a/codes/models/steps/losses.py +++ b/codes/models/steps/losses.py @@ -15,6 +15,9 @@ def create_loss(opt_loss, env): if 'teco_' in type: from models.steps.tecogan_losses import create_teco_loss return create_teco_loss(opt_loss, env) + elif 'stylegan2_' in type: + from models.archs.stylegan import create_stylegan2_loss + return create_stylegan2_loss(opt_loss, env) elif type == 'pix': return PixLoss(opt_loss, env) elif type == 'direct': @@ -37,10 +40,6 @@ def create_loss(opt_loss, env): return RecurrentLoss(opt_loss, env) elif type == 'for_element': return ForElementLoss(opt_loss, env) - elif type == 'stylegan2_divergence': - return StyleGan2DivergenceLoss(opt_loss, env) - elif type == 'stylegan2_pathlen': - return StyleGan2PathLengthLoss(opt_loss, env) else: raise NotImplementedError @@ -487,68 +486,3 @@ class ForElementLoss(ConfigurableLoss): def clear_metrics(self): self.loss.clear_metrics() - - -class StyleGan2DivergenceLoss(ConfigurableLoss): - def __init__(self, opt, env): - super().__init__(opt, env) - self.real = opt['real'] - self.fake = opt['fake'] - self.discriminator = opt['discriminator'] - self.for_gen = opt['gen_loss'] - self.gp_frequency = opt['gradient_penalty_frequency'] - self.noise = opt['noise'] if 'noise' in opt.keys() else 0 - - def forward(self, net, state): - real_input = state[self.real] - fake_input = state[self.fake] - if self.noise != 0: - fake_input = fake_input + torch.rand_like(fake_input) * self.noise - real_input = real_input + torch.rand_like(real_input) * self.noise - - D = self.env['discriminators'][self.discriminator] - fake = D(fake_input) - if self.for_gen: - return fake.mean() - else: - real_input.requires_grad_() # <-- Needed to compute gradients on the input. - real = D(real_input) - divergence_loss = (F.relu(1 + real) + F.relu(1 - fake)).mean() - - # Apply gradient penalty. TODO: migrate this elsewhere. - if self.env['step'] % self.gp_frequency == 0: - from models.archs.stylegan.stylegan2 import gradient_penalty - gp = gradient_penalty(real_input, real) - self.metrics.append(("gradient_penalty", gp.clone().detach())) - divergence_loss = divergence_loss + gp - - real_input.requires_grad_(requires_grad=False) - return divergence_loss - - -class StyleGan2PathLengthLoss(ConfigurableLoss): - def __init__(self, opt, env): - super().__init__(opt, env) - self.w_styles = opt['w_styles'] - self.gen = opt['gen'] - self.pl_mean = None - from models.archs.stylegan.stylegan2 import EMA - self.pl_length_ma = EMA(.99) - - def forward(self, net, state): - w_styles = state[self.w_styles] - gen = state[self.gen] - from models.archs.stylegan.stylegan2 import calc_pl_lengths - pl_lengths = calc_pl_lengths(w_styles, gen) - avg_pl_length = np.mean(pl_lengths.detach().cpu().numpy()) - - from models.archs.stylegan.stylegan2 import is_empty - if not is_empty(self.pl_mean): - pl_loss = ((pl_lengths - self.pl_mean) ** 2).mean() - if not torch.isnan(pl_loss): - return pl_loss - else: - print("Path length loss returned NaN!") - - self.pl_mean = self.pl_length_ma.update_average(self.pl_mean, avg_pl_length) - return 0