forked from mrq/DL-Art-School

More circular dependency fixes + unet fixes

This commit is contained in:
parent e587d549f7
commit 98eada1e4c
@@ -9,7 +9,7 @@ from torchvision import transforms
 import torch.nn as nn
 from pathlib import Path
 
-from models.archs.stylegan.stylegan2 import exists
+import models.archs.stylegan.stylegan2 as sg2
 
 
 def convert_transparent_to_rgb(image):
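Note: the recurring fix in this commit is swapping "from X import name" for "import X as alias". The "from" form binds the name at import time and raises ImportError when X is still mid-import during a circular dependency; the aliased form binds the module object and defers attribute lookup until the name is actually used. A minimal sketch of the failure mode, using hypothetical modules a and b rather than files from this repo:

# a.py
import b                  # safe: binds the module object, even if b is mid-import
def helper():
    return "a.helper"
def use_b():
    return b.helper()     # attribute resolved at call time, after b has loaded

# b.py
# from a import helper    # would raise ImportError during the a <-> b cycle,
#                         # because a is only partially initialized here
import a                  # safe: defers the lookup
def helper():
    return "b.helper"
def use_a():
    return a.helper()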
@@ -61,7 +61,7 @@ class expand_greyscale(object):
         else:
             raise Exception(f'image with invalid number of channels given {channels}')
 
-        if not exists(alpha) and self.transparent:
+        if not sg2.exists(alpha) and self.transparent:
             alpha = torch.ones(1, *tensor.shape[1:], device=tensor.device)
 
         return color if not self.transparent else torch.cat((color, alpha))
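For reference, exists() in lucidrains-derived StyleGAN2 code is conventionally the trivial null check below; this is an assumption about stylegan2.py, not a definition shown in this diff:

def exists(val):
    # treat None as "unset" -- the standard lucidrains helper
    return val is not None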
@@ -7,13 +7,14 @@ from random import random
 
 import torch
 import torch.nn.functional as F
+import models.steps.losses as L
+
 from kornia.filters import filter2D
 from linear_attention_transformer import ImageLinearAttention
 from torch import nn
 from torch.autograd import grad as torch_grad
 from vector_quantize_pytorch import VectorQuantize
 
-from models.steps.losses import ConfigurableLoss
 from utils.util import checkpoint
 
 try:
@@ -700,7 +701,7 @@ class StyleGan2Discriminator(nn.Module):
                 nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')
 
 
-class StyleGan2DivergenceLoss(ConfigurableLoss):
+class StyleGan2DivergenceLoss(L.ConfigurableLoss):
     def __init__(self, opt, env):
         super().__init__(opt, env)
         self.real = opt['real']
@@ -737,7 +738,7 @@ class StyleGan2DivergenceLoss(ConfigurableLoss):
         return divergence_loss
 
 
-class StyleGan2PathLengthLoss(ConfigurableLoss):
+class StyleGan2PathLengthLoss(L.ConfigurableLoss):
     def __init__(self, opt, env):
         super().__init__(opt, env)
         self.w_styles = opt['w_styles']
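One caveat the alias does not remove: a class statement like StyleGan2DivergenceLoss(L.ConfigurableLoss) still resolves L.ConfigurableLoss when this module's body executes, so the cycle only stays broken if models.steps.losses defines ConfigurableLoss before anything under it imports this file back. A sketch of the ordering this relies on (assumed from the import paths, not verbatim from the repo):

# models/steps/losses.py -- assumed ordering, illustrative only
import torch.nn as nn

class ConfigurableLoss(nn.Module):   # defined near the top of losses.py...
    def __init__(self, opt, env):
        super().__init__()
        self.opt, self.env = opt, env

# ...so by the time stylegan2.py runs "import models.steps.losses as L"
# and its class statements evaluate L.ConfigurableLoss, the attribute
# already exists even if losses.py has not finished importing.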
@@ -6,9 +6,8 @@ import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from models.archs.stylegan.stylegan2 import attn_and_ff
-from models.steps.losses import ConfigurableLoss
+import models.archs.stylegan.stylegan2 as sg2
+import models.steps.losses as L
 
 
 def leaky_relu(p=0.2):
@@ -96,7 +95,7 @@ class StyleGan2UnetDiscriminator(nn.Module):
             block = DownBlock(in_chan, out_chan, downsample = is_not_last)
             down_blocks.append(block)
 
-            attn_fn = attn_and_ff(out_chan)
+            attn_fn = sg2.attn_and_ff(out_chan)
             attn_blocks.append(attn_fn)
 
         self.down_blocks = nn.ModuleList(down_blocks)
@@ -171,7 +170,7 @@ def cutmix_coordinates(height, width, alpha = 1.):
     return ((y0, y1), (x0, x1)), lam
 
 
-class StyleGan2UnetDivergenceLoss(ConfigurableLoss):
+class StyleGan2UnetDivergenceLoss(L.ConfigurableLoss):
     def __init__(self, opt, env):
         super().__init__(opt, env)
         self.real = opt['real']
@@ -181,6 +180,7 @@ class StyleGan2UnetDivergenceLoss(ConfigurableLoss):
         self.gp_frequency = opt['gradient_penalty_frequency']
         self.noise = opt['noise'] if 'noise' in opt.keys() else 0
         self.image_size = opt['image_size']
+        self.cr_weight = .2
 
     def forward(self, net, state):
         real_input = state[self.real]
@@ -191,7 +191,7 @@ class StyleGan2UnetDivergenceLoss(ConfigurableLoss):
 
         D = self.env['discriminators'][self.discriminator]
         fake_dec, fake_enc = D(fake_input)
-        fake_aug_images = D.aug_images
+        fake_aug_images = D.module.aug_images
         if self.for_gen:
             return fake_enc.mean() + F.relu(1 + fake_dec).mean()
         else:
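The D.aug_images -> D.module.aug_images change suggests the registered discriminator is wrapped (DataParallel, DistributedDataParallel, or the augmentor wrapper named in a later hunk). Those wrappers proxy forward() but not arbitrary attributes, so anything stashed on the inner module must be reached through .module. A minimal single-device sketch of the failure; the wrapper semantics are stock PyTorch, while the aug_images attribute is this repo's convention:

import torch
import torch.nn as nn

class Disc(nn.Module):
    def forward(self, x):
        self.aug_images = x          # stash the (possibly augmented) batch on the module
        return x.mean(dim=1)

D = nn.DataParallel(Disc())          # on CPU/single device this simply delegates forward
out = D(torch.randn(2, 3))
# D.aug_images                       # AttributeError: the wrapper exposes no custom attrs
imgs = D.module.aug_images           # reach through the wrapper instead

(With true multi-GPU replication, attributes set inside forward land on per-device replicas, so this pattern presumes a single device or one process per GPU.)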
@@ -201,10 +201,10 @@ class StyleGan2UnetDivergenceLoss(ConfigurableLoss):
 
             real_input.requires_grad_() # <-- Needed to compute gradients on the input.
             real_dec, real_enc = D(real_input)
-            real_aug_images = D.aug_images
+            real_aug_images = D.module.aug_images
             enc_divergence = (F.relu(1 + real_enc) + F.relu(1 - fake_enc)).mean()
             dec_divergence = (F.relu(1 + real_dec) + F.relu(1 - fake_dec)).mean()
-            divergence_loss = enc_divergence + dec_divergence * dec_loss_coef
+            disc_loss = enc_divergence + dec_divergence * dec_loss_coef
 
             if apply_cutmix:
                 mask = cutmix(
@@ -217,11 +217,11 @@ class StyleGan2UnetDivergenceLoss(ConfigurableLoss):
                     mask = 1 - mask
 
                 cutmix_images = mask_src_tgt(real_aug_images, fake_aug_images, mask)
-                cutmix_enc_out, cutmix_dec_out = self.GAN.D(cutmix_images)
+                cutmix_dec_out, cutmix_enc_out = D.module.D(cutmix_images) # Bypass implied augmentor - hence D.module.D
 
                 cutmix_enc_divergence = F.relu(1 - cutmix_enc_out).mean()
                 cutmix_dec_divergence = F.relu(1 + (mask * 2 - 1) * cutmix_dec_out).mean()
-                disc_loss = divergence_loss + cutmix_enc_divergence + cutmix_dec_divergence
+                disc_loss = disc_loss + cutmix_enc_divergence + cutmix_dec_divergence
 
                 cr_cutmix_dec_out = mask_src_tgt(real_dec, fake_dec, mask)
                 cr_loss = F.mse_loss(cutmix_dec_out, cr_cutmix_dec_out) * self.cr_weight
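This hunk fixes three things at once: the dangling self.GAN reference (this loss no longer owns a GAN object; the discriminator comes from env), the unpack order (D appears to return (dec, enc), matching fake_dec, fake_enc = D(fake_input) above), and accumulation into disc_loss now that divergence_loss was renamed. The consistency term follows the U-Net GAN recipe: the decoder's per-pixel output on a CutMixed image should match the CutMix of its outputs on the unmixed images. A sketch with mask_src_tgt as it is conventionally defined (assumed, not shown in this diff):

import torch
import torch.nn.functional as F

def mask_src_tgt(source, target, mask):
    # binary mask selects source pixels, the rest come from target
    return source * mask + (1 - mask) * target

# shapes assumed: (B, 1, H, W) per-pixel decoder maps
real_dec, fake_dec = torch.randn(4, 1, 64, 64), torch.randn(4, 1, 64, 64)
mask = (torch.rand(4, 1, 64, 64) > 0.5).float()

cutmix_dec_out = torch.randn(4, 1, 64, 64)          # stand-in for D's decoder output on mixed images
cr_target = mask_src_tgt(real_dec, fake_dec, mask)  # mix of D's outputs on unmixed images
cr_loss = F.mse_loss(cutmix_dec_out, cr_target) * .2  # cr_weight = .2, added in this commit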
@@ -232,9 +232,12 @@ class StyleGan2UnetDivergenceLoss(ConfigurableLoss):
             # Apply gradient penalty. TODO: migrate this elsewhere.
             if self.env['step'] % self.gp_frequency == 0:
                 from models.archs.stylegan.stylegan2 import gradient_penalty
-                gp = gradient_penalty(real_input, real)
+                if random() < .5:
+                    gp = gradient_penalty(real_input, real_enc)
+                else:
+                    gp = gradient_penalty(real_input, real_dec) * dec_loss_coef
                 self.metrics.append(("gradient_penalty", gp.clone().detach()))
                 disc_loss = disc_loss + gp
 
             real_input.requires_grad_(requires_grad=False)
             return disc_loss
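Two fixes here: the old call passed real, which is not a variable in this scope (likely a leftover from the non-unet divergence loss), and a unet discriminator has two heads, so the penalty now alternates per GP step between the encoder and decoder outputs, scaling the decoder term by dec_loss_coef to mirror the divergence weighting. For context, the gradient_penalty being imported is presumably the lucidrains-style penalty, roughly:

import torch
from torch.autograd import grad as torch_grad

def gradient_penalty(images, output, weight=10):
    # gradient of the critic head's output w.r.t. its input images
    gradients = torch_grad(outputs=output, inputs=images,
                           grad_outputs=torch.ones_like(output),
                           create_graph=True, retain_graph=True,
                           only_inputs=True)[0]
    gradients = gradients.reshape(images.shape[0], -1)
    # penalize deviation of the per-sample gradient norm from 1
    return weight * ((gradients.norm(2, dim=1) - 1) ** 2).mean()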
@@ -2,7 +2,6 @@ import torch
 import torch.nn as nn
 from torch.cuda.amp import autocast
 
-from models.networks import define_F
 from models.loss import GANLoss
 import random
 import functools
@@ -130,7 +129,8 @@ class FeatureLoss(ConfigurableLoss):
         super(FeatureLoss, self).__init__(opt, env)
         self.opt = opt
         self.criterion = get_basic_criterion_for_name(opt['criterion'], env['device'])
-        self.netF = define_F(which_model=opt['which_model_F'],
+        import models.networks
+        self.netF = models.networks.define_F(which_model=opt['which_model_F'],
                              load_path=opt['load_path'] if 'load_path' in opt.keys() else None).to(self.env['device'])
         if not env['opt']['dist']:
             self.netF = torch.nn.parallel.DataParallel(self.netF, device_ids=env['opt']['gpu_ids'])
@@ -155,8 +155,9 @@ class InterpretedFeatureLoss(ConfigurableLoss):
         super(InterpretedFeatureLoss, self).__init__(opt, env)
         self.opt = opt
         self.criterion = get_basic_criterion_for_name(opt['criterion'], env['device'])
-        self.netF_real = define_F(which_model=opt['which_model_F']).to(self.env['device'])
-        self.netF_gen = define_F(which_model=opt['which_model_F'], load_path=opt['load_path']).to(self.env['device'])
+        import models.networks
+        self.netF_real = models.networks.define_F(which_model=opt['which_model_F']).to(self.env['device'])
+        self.netF_gen = models.networks.define_F(which_model=opt['which_model_F'], load_path=opt['load_path']).to(self.env['device'])
         if not env['opt']['dist']:
             self.netF_real = torch.nn.parallel.DataParallel(self.netF_real)
             self.netF_gen = torch.nn.parallel.DataParallel(self.netF_gen)
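These last two hunks (models/steps/losses.py, judging by the import paths) use the other standard cycle-breaker: a deferred import inside the constructor. models.networks presumably imports from this losses module, so a top-level "from models.networks import define_F" re-enters the cycle at import time; importing inside __init__ runs only at instantiation, after both modules have fully loaded. Schematically, with a hypothetical pkg standing in for the real layout:

# pkg/losses.py -- illustrative, not this repo's exact file
class FeatureLoss:
    def __init__(self, which_model):
        # Deferred import: resolved on first instantiation, when
        # pkg.networks (which imports pkg.losses back) is complete.
        import pkg.networks
        self.netF = pkg.networks.define_F(which_model=which_model)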