diff --git a/codes/data/stylegan2_dataset.py b/codes/data/stylegan2_dataset.py
index f44f1ddf..2424f686 100644
--- a/codes/data/stylegan2_dataset.py
+++ b/codes/data/stylegan2_dataset.py
@@ -9,7 +9,7 @@
 from torchvision import transforms
 import torch.nn as nn
 from pathlib import Path
-from models.archs.stylegan.stylegan2 import exists
+import models.archs.stylegan.stylegan2 as sg2
 
 
 def convert_transparent_to_rgb(image):
@@ -61,7 +61,7 @@ class expand_greyscale(object):
         else:
             raise Exception(f'image with invalid number of channels given {channels}')
 
-        if not exists(alpha) and self.transparent:
+        if not sg2.exists(alpha) and self.transparent:
             alpha = torch.ones(1, *tensor.shape[1:], device=tensor.device)
 
         return color if not self.transparent else torch.cat((color, alpha))
diff --git a/codes/models/archs/stylegan/stylegan2.py b/codes/models/archs/stylegan/stylegan2.py
index adba029f..f29475ce 100644
--- a/codes/models/archs/stylegan/stylegan2.py
+++ b/codes/models/archs/stylegan/stylegan2.py
@@ -7,13 +7,14 @@
 from random import random
 import torch
 import torch.nn.functional as F
+import models.steps.losses as L
+
 from kornia.filters import filter2D
 from linear_attention_transformer import ImageLinearAttention
 from torch import nn
 from torch.autograd import grad as torch_grad
 from vector_quantize_pytorch import VectorQuantize
 
-from models.steps.losses import ConfigurableLoss
 from utils.util import checkpoint
 
 try:
@@ -700,7 +701,7 @@ class StyleGan2Discriminator(nn.Module):
                 nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')
 
 
-class StyleGan2DivergenceLoss(ConfigurableLoss):
+class StyleGan2DivergenceLoss(L.ConfigurableLoss):
     def __init__(self, opt, env):
         super().__init__(opt, env)
         self.real = opt['real']
@@ -737,7 +738,7 @@ class StyleGan2DivergenceLoss(ConfigurableLoss):
         return divergence_loss
 
 
-class StyleGan2PathLengthLoss(ConfigurableLoss):
+class StyleGan2PathLengthLoss(L.ConfigurableLoss):
    def __init__(self, opt, env):
        super().__init__(opt, env)
        self.w_styles = opt['w_styles']
diff --git a/codes/models/archs/stylegan/stylegan2_unet_disc.py b/codes/models/archs/stylegan/stylegan2_unet_disc.py
index a4c791d0..4b319879 100644
--- a/codes/models/archs/stylegan/stylegan2_unet_disc.py
+++ b/codes/models/archs/stylegan/stylegan2_unet_disc.py
@@ -6,9 +6,8 @@
 import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-
-from models.archs.stylegan.stylegan2 import attn_and_ff
-from models.steps.losses import ConfigurableLoss
+import models.archs.stylegan.stylegan2 as sg2
+import models.steps.losses as L
 
 def leaky_relu(p=0.2):
@@ -96,7 +95,7 @@ class StyleGan2UnetDiscriminator(nn.Module):
             block = DownBlock(in_chan, out_chan, downsample = is_not_last)
             down_blocks.append(block)
 
-            attn_fn = attn_and_ff(out_chan)
+            attn_fn = sg2.attn_and_ff(out_chan)
             attn_blocks.append(attn_fn)
 
         self.down_blocks = nn.ModuleList(down_blocks)
@@ -171,7 +170,7 @@ def cutmix_coordinates(height, width, alpha = 1.):
     return ((y0, y1), (x0, x1)), lam
 
 
-class StyleGan2UnetDivergenceLoss(ConfigurableLoss):
+class StyleGan2UnetDivergenceLoss(L.ConfigurableLoss):
     def __init__(self, opt, env):
         super().__init__(opt, env)
         self.real = opt['real']
@@ -181,6 +180,7 @@ class StyleGan2UnetDivergenceLoss(ConfigurableLoss):
         self.gp_frequency = opt['gradient_penalty_frequency']
         self.noise = opt['noise'] if 'noise' in opt.keys() else 0
         self.image_size = opt['image_size']
+        self.cr_weight = .2
 
     def forward(self, net, state):
         real_input = state[self.real]
@@ -191,7 +191,7 @@ class StyleGan2UnetDivergenceLoss(ConfigurableLoss):
         D = self.env['discriminators'][self.discriminator]
 
         fake_dec, fake_enc = D(fake_input)
-        fake_aug_images = D.aug_images
+        fake_aug_images = D.module.aug_images
         if self.for_gen:
             return fake_enc.mean() + F.relu(1 + fake_dec).mean()
         else:
@@ -201,10 +201,10 @@ class StyleGan2UnetDivergenceLoss(ConfigurableLoss):
 
             real_input.requires_grad_()  # <-- Needed to compute gradients on the input.
             real_dec, real_enc = D(real_input)
-            real_aug_images = D.aug_images
+            real_aug_images = D.module.aug_images
             enc_divergence = (F.relu(1 + real_enc) + F.relu(1 - fake_enc)).mean()
             dec_divergence = (F.relu(1 + real_dec) + F.relu(1 - fake_dec)).mean()
-            divergence_loss = enc_divergence + dec_divergence * dec_loss_coef
+            disc_loss = enc_divergence + dec_divergence * dec_loss_coef
 
             if apply_cutmix:
                 mask = cutmix(
@@ -217,11 +217,11 @@ class StyleGan2UnetDivergenceLoss(ConfigurableLoss):
 
                 mask = 1 - mask
 
                 cutmix_images = mask_src_tgt(real_aug_images, fake_aug_images, mask)
-                cutmix_enc_out, cutmix_dec_out = self.GAN.D(cutmix_images)
+                cutmix_dec_out, cutmix_enc_out = D.module.D(cutmix_images)  # Bypass implied augmentor - hence D.module.D
                 cutmix_enc_divergence = F.relu(1 - cutmix_enc_out).mean()
                 cutmix_dec_divergence = F.relu(1 + (mask * 2 - 1) * cutmix_dec_out).mean()
-                disc_loss = divergence_loss + cutmix_enc_divergence + cutmix_dec_divergence
+                disc_loss = disc_loss + cutmix_enc_divergence + cutmix_dec_divergence
 
                 cr_cutmix_dec_out = mask_src_tgt(real_dec, fake_dec, mask)
                 cr_loss = F.mse_loss(cutmix_dec_out, cr_cutmix_dec_out) * self.cr_weight
@@ -232,9 +232,12 @@ class StyleGan2UnetDivergenceLoss(ConfigurableLoss):
             # Apply gradient penalty. TODO: migrate this elsewhere.
             if self.env['step'] % self.gp_frequency == 0:
                 from models.archs.stylegan.stylegan2 import gradient_penalty
-                gp = gradient_penalty(real_input, real)
+                if random() < .5:
+                    gp = gradient_penalty(real_input, real_enc)
+                else:
+                    gp = gradient_penalty(real_input, real_dec) * dec_loss_coef
                 self.metrics.append(("gradient_penalty", gp.clone().detach()))
                 disc_loss = disc_loss + gp
 
         real_input.requires_grad_(requires_grad=False)
-        return disc_loss
\ No newline at end of file
+        return disc_loss
diff --git a/codes/models/steps/losses.py b/codes/models/steps/losses.py
index ce08eb26..8364b1a6 100644
--- a/codes/models/steps/losses.py
+++ b/codes/models/steps/losses.py
@@ -2,7 +2,6 @@
 import torch
 import torch.nn as nn
 from torch.cuda.amp import autocast
-from models.networks import define_F
 from models.loss import GANLoss
 import random
 import functools
@@ -130,7 +129,8 @@ class FeatureLoss(ConfigurableLoss):
         super(FeatureLoss, self).__init__(opt, env)
         self.opt = opt
         self.criterion = get_basic_criterion_for_name(opt['criterion'], env['device'])
-        self.netF = define_F(which_model=opt['which_model_F'],
+        import models.networks
+        self.netF = models.networks.define_F(which_model=opt['which_model_F'],
                              load_path=opt['load_path'] if 'load_path' in opt.keys() else None).to(self.env['device'])
         if not env['opt']['dist']:
             self.netF = torch.nn.parallel.DataParallel(self.netF, device_ids=env['opt']['gpu_ids'])
@@ -155,8 +155,9 @@ class InterpretedFeatureLoss(ConfigurableLoss):
         super(InterpretedFeatureLoss, self).__init__(opt, env)
         self.opt = opt
         self.criterion = get_basic_criterion_for_name(opt['criterion'], env['device'])
-        self.netF_real = define_F(which_model=opt['which_model_F']).to(self.env['device'])
-        self.netF_gen = define_F(which_model=opt['which_model_F'], load_path=opt['load_path']).to(self.env['device'])
+        import models.networks
+        self.netF_real = models.networks.define_F(which_model=opt['which_model_F']).to(self.env['device'])
+        self.netF_gen = models.networks.define_F(which_model=opt['which_model_F'], load_path=opt['load_path']).to(self.env['device'])
         if not env['opt']['dist']:
             self.netF_real = torch.nn.parallel.DataParallel(self.netF_real)
             self.netF_gen = torch.nn.parallel.DataParallel(self.netF_gen)