Move stylegan2 around, bring in unet
This commit is contained in:
parent 4c6b14a3f8
commit 5cade6b874
@@ -9,7 +9,7 @@ from torchvision import transforms
 import torch.nn as nn
 from pathlib import Path
 
-from models.archs.stylegan2 import exists
+from models.archs.stylegan.stylegan2 import exists
 
 
 def convert_transparent_to_rgb(image):
codes/models/archs/stylegan/stylegan2_unet_disc.py (new file, 124 lines)
@@ -0,0 +1,124 @@
+from functools import partial
+from math import log2
+
+import torch
+import torch.nn as nn
+
+def leaky_relu(p=0.2):
+    return nn.LeakyReLU(p)
+
+
+def double_conv(chan_in, chan_out):
+    return nn.Sequential(
+        nn.Conv2d(chan_in, chan_out, 3, padding=1),
+        leaky_relu(),
+        nn.Conv2d(chan_out, chan_out, 3, padding=1),
+        leaky_relu()
+    )
+
+
+class DownBlock(nn.Module):
+    def __init__(self, input_channels, filters, downsample=True):
+        super().__init__()
+        self.conv_res = nn.Conv2d(input_channels, filters, 1, stride = (2 if downsample else 1))
+
+        self.net = double_conv(input_channels, filters)
+        self.down = nn.Conv2d(filters, filters, 3, padding = 1, stride = 2) if downsample else None
+
+    def forward(self, x):
+        res = self.conv_res(x)
+        x = self.net(x)
+        unet_res = x
+
+        if self.down is not None:
+            x = self.down(x)
+
+        x = x + res
+        return x, unet_res
+
+
+class UpBlock(nn.Module):
+    def __init__(self, input_channels, filters):
+        super().__init__()
+        self.conv_res = nn.ConvTranspose2d(input_channels // 2, filters, 1, stride = 2)
+        self.net = double_conv(input_channels, filters)
+        self.up = nn.Upsample(scale_factor = 2, mode='bilinear', align_corners=False)
+        self.input_channels = input_channels
+        self.filters = filters
+
+    def forward(self, x, res):
+        *_, h, w = x.shape
+        conv_res = self.conv_res(x, output_size = (h * 2, w * 2))
+        x = self.up(x)
+        x = torch.cat((x, res), dim=1)
+        x = self.net(x)
+        x = x + conv_res
+        return x
+
+
+class StyleGan2UnetDiscriminator(nn.Module):
+    def __init__(self, image_size, network_capacity = 16, fmap_max = 512, input_filters=3):
+        super().__init__()
+        num_layers = int(log2(image_size) - 3)
+
+        blocks = []
+        filters = [input_filters] + [(network_capacity) * (2 ** i) for i in range(num_layers + 1)]
+
+        set_fmap_max = partial(min, fmap_max)
+        filters = list(map(set_fmap_max, filters))
+        filters[-1] = filters[-2]
+
+        chan_in_out = list(zip(filters[:-1], filters[1:]))
+        chan_in_out = list(map(list, chan_in_out))
+
+        down_blocks = []
+        attn_blocks = []
+
+        for ind, (in_chan, out_chan) in enumerate(chan_in_out):
+            num_layer = ind + 1
+            is_not_last = ind != (len(chan_in_out) - 1)
+
+            block = DownBlock(in_chan, out_chan, downsample = is_not_last)
+            down_blocks.append(block)
+
+            attn_fn = attn_and_ff(out_chan)
+            attn_blocks.append(attn_fn)
+
+        self.down_blocks = nn.ModuleList(down_blocks)
+        self.attn_blocks = nn.ModuleList(attn_blocks)
+
+        last_chan = filters[-1]
+
+        self.to_logit = nn.Sequential(
+            leaky_relu(),
+            nn.AvgPool2d(image_size // (2 ** num_layers)),
+            Flatten(1),
+            nn.Linear(last_chan, 1)
+        )
+
+        self.conv = double_conv(last_chan, last_chan)
+
+        dec_chan_in_out = chan_in_out[:-1][::-1]
+        self.up_blocks = nn.ModuleList(list(map(lambda c: UpBlock(c[1] * 2, c[0]), dec_chan_in_out)))
+        self.conv_out = nn.Conv2d(3, 1, 1)
+
+    def forward(self, x):
+        b, *_ = x.shape
+
+        residuals = []
+
+        for (down_block, attn_block) in zip(self.down_blocks, self.attn_blocks):
+            x, unet_res = down_block(x)
+            residuals.append(unet_res)
+
+            if attn_block is not None:
+                x = attn_block(x)
+
+        x = self.conv(x) + x
+        enc_out = self.to_logit(x)
+
+        for (up_block, res) in zip(self.up_blocks, residuals[:-1][::-1]):
+            x = up_block(x, res)
+
+        dec_out = self.conv_out(x)
+        return enc_out.squeeze(), dec_out
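
Editor's note: `attn_and_ff` and `Flatten` are referenced but neither defined nor imported in the 124 lines above; they are presumably expected from the relocated `models.archs.stylegan.stylegan2` module that the rest of this commit points at. A minimal smoke-test sketch of the new discriminator, assuming those names resolve and the module imports as shown (not part of the commit):

# Hypothetical usage sketch. Assumes attn_and_ff and Flatten resolve
# (e.g. from models.archs.stylegan.stylegan2).
import torch
from models.archs.stylegan.stylegan2_unet_disc import StyleGan2UnetDiscriminator

disc = StyleGan2UnetDiscriminator(image_size=128, input_filters=3)
imgs = torch.randn(4, 3, 128, 128)
enc_out, dec_out = disc(imgs)
# enc_out: shape (4,)            -- one global real/fake logit per image (encoder head)
# dec_out: shape (4, 1, 128, 128) -- per-pixel real/fake map from the U-Net decoder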
@@ -22,7 +22,8 @@ from models.archs.stylegan.Discriminator_StyleGAN import StyleGanDiscriminator
 from models.archs.pyramid_arch import BasicResamplingFlowNet
 from models.archs.rrdb_with_adain_latent import AdaRRDBNet, LinearLatentEstimator
 from models.archs.rrdb_with_latent import LatentEstimator, RRDBNetWithLatent, LatentEstimator2
-from models.archs.stylegan2 import StyleGan2GeneratorWithLatent, StyleGan2Discriminator, StyleGan2Augmentor
+from models.archs.stylegan.stylegan2 import StyleGan2GeneratorWithLatent, StyleGan2Discriminator, StyleGan2Augmentor
+from models.archs.stylegan.stylegan2_unet_disc import StyleGan2UnetDiscriminator
 from models.archs.teco_resgen import TecoGen
 
 logger = logging.getLogger('base')
@@ -200,6 +201,9 @@ def define_D_net(opt_net, img_sz=None, wrap=False):
         attn = opt_net['attn_layers'] if 'attn_layers' in opt_net.keys() else []
         disc = StyleGan2Discriminator(image_size=opt_net['image_size'], input_filters=opt_net['in_nc'], attn_layers=attn)
         netD = StyleGan2Augmentor(disc, opt_net['image_size'], types=opt_net['augmentation_types'], prob=opt_net['augmentation_probability'])
+    elif which_model == "stylegan2_unet":
+        disc = StyleGan2UnetDiscriminator(image_size=opt_net['image_size'], input_filters=opt_net['in_nc'])
+        netD = StyleGan2Augmentor(disc, opt_net['image_size'], types=opt_net['augmentation_types'], prob=opt_net['augmentation_probability'])
     else:
         raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model))
     return netD
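
The new branch is driven entirely by opt_net, so enabling it is a config-level change. An illustrative sketch of the options it reads, using only the keys visible in the hunk above; the key that selects `which_model` and the augmentation values are assumptions, not shown in this diff:

# Illustrative only. 'which_model_D' is an assumed key name for `which_model`,
# and the augmentation entries are placeholder values.
opt_net = {
    'which_model_D': 'stylegan2_unet',
    'image_size': 128,
    'in_nc': 3,
    'augmentation_types': ['translation', 'cutout'],
    'augmentation_probability': 0.25,
}
netD = define_D_net(opt_net, img_sz=128)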
@@ -517,7 +517,7 @@ class StyleGan2DivergenceLoss(ConfigurableLoss):
 
         # Apply gradient penalty. TODO: migrate this elsewhere.
         if self.env['step'] % self.gp_frequency == 0:
-            from models.archs.stylegan2 import gradient_penalty
+            from models.archs.stylegan.stylegan2 import gradient_penalty
             gp = gradient_penalty(real_input, real)
             self.metrics.append(("gradient_penalty", gp.clone().detach()))
             divergence_loss = divergence_loss + gp
@@ -532,17 +532,17 @@ class StyleGan2PathLengthLoss(ConfigurableLoss):
         self.w_styles = opt['w_styles']
         self.gen = opt['gen']
         self.pl_mean = None
-        from models.archs.stylegan2 import EMA
+        from models.archs.stylegan.stylegan2 import EMA
         self.pl_length_ma = EMA(.99)
 
     def forward(self, net, state):
         w_styles = state[self.w_styles]
         gen = state[self.gen]
-        from models.archs.stylegan2 import calc_pl_lengths
+        from models.archs.stylegan.stylegan2 import calc_pl_lengths
         pl_lengths = calc_pl_lengths(w_styles, gen)
         avg_pl_length = np.mean(pl_lengths.detach().cpu().numpy())
 
-        from models.archs.stylegan2 import is_empty
+        from models.archs.stylegan.stylegan2 import is_empty
         if not is_empty(self.pl_mean):
            pl_loss = ((pl_lengths - self.pl_mean) ** 2).mean()
            if not torch.isnan(pl_loss):