DL-Art-School/codes/models/stylegan/stylegan2_lucidrains.py

964 lines
33 KiB
Python

import functools
import math
import multiprocessing
from contextlib import contextmanager, ExitStack
from functools import partial
from math import log2
from random import random
import torch
import torch.nn.functional as F
import trainer.losses as L
import numpy as np
from kornia.filters import filter2D
from linear_attention_transformer import ImageLinearAttention
from torch import nn
from torch.autograd import grad as torch_grad
from vector_quantize_pytorch import VectorQuantize
from trainer.networks import register_model
from utils.util import checkpoint, opt_get
try:
from apex import amp
APEX_AVAILABLE = True
except:
APEX_AVAILABLE = False
assert torch.cuda.is_available(), 'You need to have an Nvidia GPU with CUDA installed.'
num_cores = multiprocessing.cpu_count()
# constants
EPS = 1e-8
CALC_FID_NUM_IMAGES = 12800
# helper classes
def DiffAugment(x, types=[]):
for p in types:
for f in AUGMENT_FNS[p]:
x = f(x)
return x.contiguous()
def rand_brightness(x):
x = x + (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5)
return x
def rand_saturation(x):
x_mean = x.mean(dim=1, keepdim=True)
x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2) + x_mean
return x
def rand_contrast(x):
x_mean = x.mean(dim=[1, 2, 3], keepdim=True)
x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5) + x_mean
return x
def rand_translation(x, ratio=0.125):
shift_x, shift_y = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
translation_x = torch.randint(-shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device)
translation_y = torch.randint(-shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device)
grid_batch, grid_x, grid_y = torch.meshgrid(
torch.arange(x.size(0), dtype=torch.long, device=x.device),
torch.arange(x.size(2), dtype=torch.long, device=x.device),
torch.arange(x.size(3), dtype=torch.long, device=x.device),
)
grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1)
grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1)
x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0])
x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x, grid_y].permute(0, 3, 1, 2)
return x
def rand_cutout(x, ratio=0.5):
cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
offset_x = torch.randint(0, x.size(2) + (1 - cutout_size[0] % 2), size=[x.size(0), 1, 1], device=x.device)
offset_y = torch.randint(0, x.size(3) + (1 - cutout_size[1] % 2), size=[x.size(0), 1, 1], device=x.device)
grid_batch, grid_x, grid_y = torch.meshgrid(
torch.arange(x.size(0), dtype=torch.long, device=x.device),
torch.arange(cutout_size[0], dtype=torch.long, device=x.device),
torch.arange(cutout_size[1], dtype=torch.long, device=x.device),
)
grid_x = torch.clamp(grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1)
grid_y = torch.clamp(grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1)
mask = torch.ones(x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device)
mask[grid_batch, grid_x, grid_y] = 0
x = x * mask.unsqueeze(1)
return x
AUGMENT_FNS = {
'color': [rand_brightness, rand_saturation, rand_contrast],
'translation': [rand_translation],
'cutout': [rand_cutout],
}
class NanException(Exception):
pass
class EMA():
def __init__(self, beta):
super().__init__()
self.beta = beta
def update_average(self, old, new):
if not exists(old):
return new
return old * self.beta + (1 - self.beta) * new
class Flatten(nn.Module):
def forward(self, x):
return x.reshape(x.shape[0], -1)
class Residual(nn.Module):
def __init__(self, fn):
super().__init__()
self.fn = fn
def forward(self, x):
return self.fn(x) + x
class Rezero(nn.Module):
def __init__(self, fn):
super().__init__()
self.fn = fn
self.g = nn.Parameter(torch.zeros(1))
def forward(self, x):
return self.fn(x) * self.g
class PermuteToFrom(nn.Module):
def __init__(self, fn):
super().__init__()
self.fn = fn
def forward(self, x):
x = x.permute(0, 2, 3, 1)
out, loss = self.fn(x)
out = out.permute(0, 3, 1, 2)
return out, loss
class Blur(nn.Module):
def __init__(self):
super().__init__()
f = torch.Tensor([1, 2, 1])
self.register_buffer('f', f)
def forward(self, x):
f = self.f
f = f[None, None, :] * f[None, :, None]
return filter2D(x, f, normalized=True)
# one layer of self-attention and feedforward, for images
attn_and_ff = lambda chan: nn.Sequential(*[
Residual(Rezero(ImageLinearAttention(chan, norm_queries=True))),
Residual(Rezero(nn.Sequential(nn.Conv2d(chan, chan * 2, 1), leaky_relu(), nn.Conv2d(chan * 2, chan, 1))))
])
# helpers
def exists(val):
return val is not None
@contextmanager
def null_context():
yield
def combine_contexts(contexts):
@contextmanager
def multi_contexts():
with ExitStack() as stack:
yield [stack.enter_context(ctx()) for ctx in contexts]
return multi_contexts
def default(value, d):
return value if exists(value) else d
def cycle(iterable):
while True:
for i in iterable:
yield i
def cast_list(el):
return el if isinstance(el, list) else [el]
def is_empty(t):
if isinstance(t, torch.Tensor):
return t.nelement() == 0
return not exists(t)
def raise_if_nan(t):
if torch.isnan(t):
raise NanException
def gradient_accumulate_contexts(gradient_accumulate_every, is_ddp, ddps):
if is_ddp:
num_no_syncs = gradient_accumulate_every - 1
head = [combine_contexts(map(lambda ddp: ddp.no_sync, ddps))] * num_no_syncs
tail = [null_context]
contexts = head + tail
else:
contexts = [null_context] * gradient_accumulate_every
for context in contexts:
with context():
yield
def loss_backwards(fp16, loss, optimizer, loss_id, **kwargs):
if fp16:
with amp.scale_loss(loss, optimizer, loss_id) as scaled_loss:
scaled_loss.backward(**kwargs)
else:
loss.backward(**kwargs)
def gradient_penalty(images, output, weight=10, return_structured_grads=False):
batch_size = images.shape[0]
gradients = torch_grad(outputs=output, inputs=images,
grad_outputs=torch.ones(output.size(), device=images.device),
create_graph=True, retain_graph=True, only_inputs=True)[0]
flat_grad = gradients.reshape(batch_size, -1)
penalty = weight * ((flat_grad.norm(2, dim=1) - 1) ** 2).mean()
if return_structured_grads:
return penalty, gradients
else:
return penalty
def calc_pl_lengths(styles, images):
device = images.device
num_pixels = images.shape[2] * images.shape[3]
pl_noise = torch.randn(images.shape, device=device) / math.sqrt(num_pixels)
outputs = (images * pl_noise).sum()
pl_grads = torch_grad(outputs=outputs, inputs=styles,
grad_outputs=torch.ones(outputs.shape, device=device),
create_graph=True, retain_graph=True, only_inputs=True)[0]
return (pl_grads ** 2).sum(dim=2).mean(dim=1).sqrt()
def image_noise(n, im_size, device):
return torch.FloatTensor(n, im_size, im_size, 1).uniform_(0., 1.).cuda(device)
def leaky_relu(p=0.2):
return nn.LeakyReLU(p, inplace=True)
def evaluate_in_chunks(max_batch_size, model, *args):
split_args = list(zip(*list(map(lambda x: x.split(max_batch_size, dim=0), args))))
chunked_outputs = [model(*i) for i in split_args]
if len(chunked_outputs) == 1:
return chunked_outputs[0]
return torch.cat(chunked_outputs, dim=0)
def set_requires_grad(model, bool):
for p in model.parameters():
p.requires_grad = bool
def slerp(val, low, high):
low_norm = low / torch.norm(low, dim=1, keepdim=True)
high_norm = high / torch.norm(high, dim=1, keepdim=True)
omega = torch.acos((low_norm * high_norm).sum(1))
so = torch.sin(omega)
res = (torch.sin((1.0 - val) * omega) / so).unsqueeze(1) * low + (torch.sin(val * omega) / so).unsqueeze(1) * high
return res
# augmentations
def random_hflip(tensor, prob):
if prob > random():
return tensor
return torch.flip(tensor, dims=(3,))
class StyleGan2Augmentor(nn.Module):
def __init__(self, D, image_size, types, prob):
super().__init__()
self.D = D
self.prob = prob
self.types = types
def forward(self, images, detach=False):
if random() < self.prob:
images = random_hflip(images, prob=0.5)
images = DiffAugment(images, types=self.types)
if detach:
images = images.detach()
# Save away for use elsewhere (e.g. unet loss)
self.aug_images = images
return self.D(images)
def network_loaded(self):
self.D.network_loaded()
# stylegan2 classes
class EqualLinear(nn.Module):
def __init__(self, in_dim, out_dim, lr_mul=1, bias=True):
super().__init__()
self.weight = nn.Parameter(torch.randn(out_dim, in_dim))
if bias:
self.bias = nn.Parameter(torch.zeros(out_dim))
self.lr_mul = lr_mul
def forward(self, input):
return F.linear(input, self.weight * self.lr_mul, bias=self.bias * self.lr_mul)
class StyleVectorizer(nn.Module):
def __init__(self, emb, depth, lr_mul=0.1):
super().__init__()
layers = []
for i in range(depth):
layers.extend([EqualLinear(emb, emb, lr_mul), leaky_relu()])
self.net = nn.Sequential(*layers)
def forward(self, x):
x = F.normalize(x, dim=1)
return self.net(x)
class RGBBlock(nn.Module):
def __init__(self, latent_dim, input_channel, upsample, rgba=False):
super().__init__()
self.input_channel = input_channel
self.to_style = nn.Linear(latent_dim, input_channel)
out_filters = 3 if not rgba else 4
self.conv = Conv2DMod(input_channel, out_filters, 1, demod=False)
self.upsample = nn.Sequential(
nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
Blur()
) if upsample else None
def forward(self, x, prev_rgb, istyle):
b, c, h, w = x.shape
style = self.to_style(istyle)
x = self.conv(x, style)
if exists(prev_rgb):
x = x + prev_rgb
if exists(self.upsample):
x = self.upsample(x)
return x
class AdaptiveInstanceNorm(nn.Module):
def __init__(self, in_channel, style_dim):
super().__init__()
from models.archs.arch_util import ConvGnLelu
self.style2scale = ConvGnLelu(style_dim, in_channel, kernel_size=1, norm=False, activation=False, bias=True)
self.style2bias = ConvGnLelu(style_dim, in_channel, kernel_size=1, norm=False, activation=False, bias=True, weight_init_factor=0)
self.norm = nn.InstanceNorm2d(in_channel)
def forward(self, input, style):
gamma = self.style2scale(style)
beta = self.style2bias(style)
out = self.norm(input)
out = gamma * out + beta
return out
class NoiseInjection(nn.Module):
def __init__(self, channel):
super().__init__()
self.weight = nn.Parameter(torch.zeros(1, channel, 1, 1))
def forward(self, image, noise):
return image + self.weight * noise
class EqualLR:
def __init__(self, name):
self.name = name
def compute_weight(self, module):
weight = getattr(module, self.name + '_orig')
fan_in = weight.data.size(1) * weight.data[0][0].numel()
return weight * math.sqrt(2 / fan_in)
@staticmethod
def apply(module, name):
fn = EqualLR(name)
weight = getattr(module, name)
del module._parameters[name]
module.register_parameter(name + '_orig', nn.Parameter(weight.data))
module.register_forward_pre_hook(fn)
return fn
def __call__(self, module, input):
weight = self.compute_weight(module)
setattr(module, self.name, weight)
def equal_lr(module, name='weight'):
EqualLR.apply(module, name)
return module
class EqualConv2d(nn.Module):
def __init__(self, *args, **kwargs):
super().__init__()
conv = nn.Conv2d(*args, **kwargs)
conv.weight.data.normal_()
conv.bias.data.zero_()
self.conv = equal_lr(conv)
def forward(self, input):
return self.conv(input)
class Conv2DMod(nn.Module):
def __init__(self, in_chan, out_chan, kernel, demod=True, stride=1, dilation=1, **kwargs):
super().__init__()
self.filters = out_chan
self.demod = demod
self.kernel = kernel
self.stride = stride
self.dilation = dilation
self.weight = nn.Parameter(torch.randn((out_chan, in_chan, kernel, kernel)))
nn.init.kaiming_normal_(self.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')
def _get_same_padding(self, size, kernel, dilation, stride):
return ((size - 1) * (stride - 1) + dilation * (kernel - 1)) // 2
def forward(self, x, y):
b, c, h, w = x.shape
w1 = y[:, None, :, None, None]
w2 = self.weight[None, :, :, :, :]
weights = w2 * (w1 + 1)
if self.demod:
d = torch.rsqrt((weights ** 2).sum(dim=(2, 3, 4), keepdim=True) + EPS)
weights = weights * d
x = x.reshape(1, -1, h, w)
_, _, *ws = weights.shape
weights = weights.reshape(b * self.filters, *ws)
padding = self._get_same_padding(h, self.kernel, self.dilation, self.stride)
x = F.conv2d(x, weights, padding=padding, groups=b)
x = x.reshape(-1, self.filters, h, w)
return x
class GeneratorBlockWithStructure(nn.Module):
def __init__(self, latent_dim, input_channels, filters, upsample=True, upsample_rgb=True, rgba=False):
super().__init__()
self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False) if upsample else None
# Uses stylegan1 style blocks for injecting structural latent.
self.conv0 = EqualConv2d(input_channels, filters, 3, padding=1)
self.to_noise0 = nn.Linear(1, filters)
self.noise0 = equal_lr(NoiseInjection(filters))
self.adain0 = AdaptiveInstanceNorm(filters, latent_dim)
self.to_style1 = nn.Linear(latent_dim, filters)
self.to_noise1 = nn.Linear(1, filters)
self.conv1 = Conv2DMod(filters, filters, 3)
self.to_style2 = nn.Linear(latent_dim, filters)
self.to_noise2 = nn.Linear(1, filters)
self.conv2 = Conv2DMod(filters, filters, 3)
self.activation = leaky_relu()
self.to_rgb = RGBBlock(latent_dim, filters, upsample_rgb, rgba)
def forward(self, x, prev_rgb, istyle, inoise, structure_input):
if exists(self.upsample):
x = self.upsample(x)
inoise = inoise[:, :x.shape[2], :x.shape[3], :]
noise0 = self.to_noise0(inoise).permute((0, 3, 1, 2))
noise1 = self.to_noise1(inoise).permute((0, 3, 1, 2))
noise2 = self.to_noise2(inoise).permute((0, 3, 1, 2))
structure = torch.nn.functional.interpolate(structure_input, size=x.shape[2:], mode="nearest")
x = self.conv0(x)
x = self.noise0(x, noise0)
x = self.adain0(x, structure)
style1 = self.to_style1(istyle)
x = self.conv1(x, style1)
x = self.activation(x + noise1)
style2 = self.to_style2(istyle)
x = self.conv2(x, style2)
x = self.activation(x + noise2)
rgb = self.to_rgb(x, prev_rgb, istyle)
return x, rgb
class GeneratorBlock(nn.Module):
def __init__(self, latent_dim, input_channels, filters, upsample=True, upsample_rgb=True, rgba=False, structure_input=False):
super().__init__()
self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False) if upsample else None
self.structure_input = structure_input
if self.structure_input:
self.structure_conv = nn.Conv2d(3, input_channels, 3, padding=1)
input_channels = input_channels * 2
self.to_style1 = nn.Linear(latent_dim, input_channels)
self.to_noise1 = nn.Linear(1, filters)
self.conv1 = Conv2DMod(input_channels, filters, 3)
self.to_style2 = nn.Linear(latent_dim, filters)
self.to_noise2 = nn.Linear(1, filters)
self.conv2 = Conv2DMod(filters, filters, 3)
self.activation = leaky_relu()
self.to_rgb = RGBBlock(latent_dim, filters, upsample_rgb, rgba)
def forward(self, x, prev_rgb, istyle, inoise, structure_input=None):
if exists(self.upsample):
x = self.upsample(x)
if self.structure_input:
s = self.structure_conv(structure_input)
x = torch.cat([x, s], dim=1)
inoise = inoise[:, :x.shape[2], :x.shape[3], :]
noise1 = self.to_noise1(inoise).permute((0, 3, 2, 1))
noise2 = self.to_noise2(inoise).permute((0, 3, 2, 1))
style1 = self.to_style1(istyle)
x = self.conv1(x, style1)
x = self.activation(x + noise1)
style2 = self.to_style2(istyle)
x = self.conv2(x, style2)
x = self.activation(x + noise2)
rgb = self.to_rgb(x, prev_rgb, istyle)
return x, rgb
class Generator(nn.Module):
def __init__(self, image_size, latent_dim, network_capacity=16, transparent=False, attn_layers=[], no_const=False,
fmap_max=512, structure_input=False):
super().__init__()
self.image_size = image_size
self.latent_dim = latent_dim
self.num_layers = int(log2(image_size) - 1)
filters = [network_capacity * (2 ** (i + 1)) for i in range(self.num_layers)][::-1]
set_fmap_max = partial(min, fmap_max)
filters = list(map(set_fmap_max, filters))
init_channels = filters[0]
filters = [init_channels, *filters]
in_out_pairs = zip(filters[:-1], filters[1:])
self.no_const = no_const
if no_const:
self.to_initial_block = nn.ConvTranspose2d(latent_dim, init_channels, 4, 1, 0, bias=False)
else:
self.initial_block = nn.Parameter(torch.randn((1, init_channels, 4, 4)))
self.initial_conv = nn.Conv2d(filters[0], filters[0], 3, padding=1)
self.blocks = nn.ModuleList([])
self.attns = nn.ModuleList([])
for ind, (in_chan, out_chan) in enumerate(in_out_pairs):
not_first = ind != 0
not_last = ind != (self.num_layers - 1)
num_layer = self.num_layers - ind
attn_fn = attn_and_ff(in_chan) if num_layer in attn_layers else None
self.attns.append(attn_fn)
if structure_input:
block_fn = GeneratorBlockWithStructure
else:
block_fn = GeneratorBlock
block = block_fn(
latent_dim,
in_chan,
out_chan,
upsample=not_first,
upsample_rgb=not_last,
rgba=transparent
)
self.blocks.append(block)
def forward(self, styles, input_noise, structure_input=None, starting_shape=None):
batch_size = styles.shape[0]
image_size = self.image_size
if self.no_const:
avg_style = styles.mean(dim=1)[:, :, None, None]
x = self.to_initial_block(avg_style)
else:
x = self.initial_block.expand(batch_size, -1, -1, -1)
if starting_shape is not None:
x = F.interpolate(x, size=starting_shape, mode="bilinear")
rgb = None
styles = styles.transpose(0, 1)
x = self.initial_conv(x)
if structure_input is not None:
s = torch.nn.functional.interpolate(structure_input, size=x.shape[2:], mode="nearest")
for style, block, attn in zip(styles, self.blocks, self.attns):
if exists(attn):
x = checkpoint(attn, x)
if structure_input is not None:
if exists(block.upsample):
# In this case, the structural guidance is given by the extra information over the previous layer.
twoX = (x.shape[2]*2, x.shape[3]*2)
sn = torch.nn.functional.interpolate(structure_input, size=twoX, mode="nearest")
s_int = torch.nn.functional.interpolate(s, size=twoX, mode="bilinear")
s_diff = sn - s_int
else:
# This is the initial case - just feed in the base structure.
s_diff = s
else:
s_diff = None
x, rgb = checkpoint(block, x, rgb, style, input_noise, s_diff)
return rgb
# Wrapper that combines style vectorizer with the actual generator.
class StyleGan2GeneratorWithLatent(nn.Module):
def __init__(self, image_size, latent_dim=512, style_depth=8, lr_mlp=.1, network_capacity=16, transparent=False,
attn_layers=[], no_const=False, fmap_max=512, structure_input=False):
super().__init__()
self.vectorizer = StyleVectorizer(latent_dim, style_depth, lr_mul=lr_mlp)
self.gen = Generator(image_size, latent_dim, network_capacity, transparent, attn_layers, no_const, fmap_max,
structure_input=structure_input)
self.mixed_prob = .9
self._init_weights()
def noise(self, n, latent_dim, device):
return torch.randn(n, latent_dim).cuda(device)
def noise_list(self, n, layers, latent_dim, device):
return [(self.noise(n, latent_dim, device), layers)]
def mixed_list(self, n, layers, latent_dim, device):
tt = int(torch.rand(()).numpy() * layers)
return self.noise_list(n, tt, latent_dim, device) + self.noise_list(n, layers - tt, latent_dim, device)
def latent_to_w(self, style_vectorizer, latent_descr):
return [(style_vectorizer(z), num_layers) for z, num_layers in latent_descr]
def styles_def_to_tensor(self, styles_def):
return torch.cat([t[:, None, :].expand(-1, n, -1) for t, n in styles_def], dim=1)
# To use per the stylegan paper, input should be uniform noise. This gen takes it in as a normal "image" format:
# b,f,h,w.
def forward(self, x, structure_input=None, fit_starting_shape_to_structure=False):
b, f, h, w = x.shape
full_random_latents = True
if full_random_latents:
style = self.noise(b*2, self.gen.latent_dim, x.device)
w = self.vectorizer(style)
# Randomly distribute styles across layers
w_styles = w[:,None,:].expand(-1, self.gen.num_layers, -1).clone()
for j in range(b):
cutoff = int(torch.rand(()).numpy() * self.gen.num_layers)
if cutoff == self.gen.num_layers or random() > self.mixed_prob:
w_styles[j] = w_styles[j*2]
else:
w_styles[j, :cutoff] = w_styles[j*2, :cutoff]
w_styles[j, cutoff:] = w_styles[j*2+1, cutoff:]
w_styles = w_styles[:b]
else:
get_latents_fn = self.mixed_list if random() < self.mixed_prob else self.noise_list
style = get_latents_fn(b, self.gen.num_layers, self.gen.latent_dim, device=x.device)
w_space = self.latent_to_w(self.vectorizer, style)
w_styles = self.styles_def_to_tensor(w_space)
starting_shape = None
if fit_starting_shape_to_structure:
starting_shape = (x.shape[2] // 32, x.shape[3] // 32)
# The underlying model expects the noise as b,h,w,1. Make it so.
return self.gen(w_styles, x[:,0,:,:].unsqueeze(dim=3), structure_input, starting_shape), w_styles
def _init_weights(self):
for m in self.modules():
if type(m) in {nn.Conv2d, nn.Linear} and hasattr(m, 'weight'):
nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')
for block in self.gen.blocks:
nn.init.zeros_(block.to_noise1.weight)
nn.init.zeros_(block.to_noise2.weight)
nn.init.zeros_(block.to_noise1.bias)
nn.init.zeros_(block.to_noise2.bias)
class DiscriminatorBlock(nn.Module):
def __init__(self, input_channels, filters, downsample=True):
super().__init__()
self.filters = filters
self.conv_res = nn.Conv2d(input_channels, filters, 1, stride=(2 if downsample else 1))
self.net = nn.Sequential(
nn.Conv2d(input_channels, filters, 3, padding=1),
leaky_relu(),
nn.Conv2d(filters, filters, 3, padding=1),
leaky_relu()
)
self.downsample = nn.Sequential(
Blur(),
nn.Conv2d(filters, filters, 3, padding=1, stride=2)
) if downsample else None
def forward(self, x):
res = self.conv_res(x)
x = self.net(x)
if exists(self.downsample):
x = self.downsample(x)
x = (x + res) * (1 / math.sqrt(2))
return x
class StyleGan2Discriminator(nn.Module):
def __init__(self, image_size, network_capacity=16, fq_layers=[], fq_dict_size=256, attn_layers=[],
transparent=False, fmap_max=512, input_filters=3, quantize=False, do_checkpointing=False, mlp=False):
super().__init__()
num_layers = int(log2(image_size) - 1)
blocks = []
filters = [input_filters] + [(64) * (2 ** i) for i in range(num_layers + 1)]
set_fmap_max = partial(min, fmap_max)
filters = list(map(set_fmap_max, filters))
chan_in_out = list(zip(filters[:-1], filters[1:]))
blocks = []
attn_blocks = []
quantize_blocks = []
for ind, (in_chan, out_chan) in enumerate(chan_in_out):
num_layer = ind + 1
is_not_last = ind != (len(chan_in_out) - 1)
block = DiscriminatorBlock(in_chan, out_chan, downsample=is_not_last)
blocks.append(block)
attn_fn = attn_and_ff(out_chan) if num_layer in attn_layers else None
attn_blocks.append(attn_fn)
if quantize:
quantize_fn = PermuteToFrom(VectorQuantize(out_chan, fq_dict_size)) if num_layer in fq_layers else None
quantize_blocks.append(quantize_fn)
else:
quantize_blocks.append(None)
self.blocks = nn.ModuleList(blocks)
self.attn_blocks = nn.ModuleList(attn_blocks)
self.quantize_blocks = nn.ModuleList(quantize_blocks)
self.do_checkpointing = do_checkpointing
chan_last = filters[-1]
latent_dim = 2 * 2 * chan_last
self.final_conv = nn.Conv2d(chan_last, chan_last, 3, padding=1)
self.flatten = Flatten()
if mlp:
self.to_logit = nn.Sequential(nn.Linear(latent_dim, 100),
nn.Linear(100, 1))
else:
self.to_logit = nn.Linear(latent_dim, 1)
self._init_weights()
def forward(self, x):
b, *_ = x.shape
quantize_loss = torch.zeros(1).to(x)
for (block, attn_block, q_block) in zip(self.blocks, self.attn_blocks, self.quantize_blocks):
if self.do_checkpointing:
x = checkpoint(block, x)
else:
x = block(x)
if exists(attn_block):
x = attn_block(x)
if exists(q_block):
x, _, loss = q_block(x)
quantize_loss += loss
x = self.final_conv(x)
x = self.flatten(x)
x = self.to_logit(x)
if exists(q_block):
return x.squeeze(), quantize_loss
else:
return x.squeeze()
def _init_weights(self):
for m in self.modules():
if type(m) in {nn.Conv2d, nn.Linear}:
nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')
# Configures the network as partially pre-trained. This means:
# 1) The top (high-resolution) `num_blocks` will have their weights re-initialized.
# 2) The haed (linear layers) will also have their weights re-initialized
# 3) All intermediate blocks will be frozen until step `frozen_until_step`
# These settings will be applied after the weights have been loaded (network_loaded())
def configure_partial_training(self, bypass_blocks=0, num_blocks=2, frozen_until_step=0):
self.bypass_blocks = bypass_blocks
self.num_blocks = num_blocks
self.frozen_until_step = frozen_until_step
# Called after the network weights are loaded.
def network_loaded(self):
if not hasattr(self, 'frozen_until_step'):
return
if self.bypass_blocks > 0:
self.blocks = self.blocks[self.bypass_blocks:]
self.blocks[0] = DiscriminatorBlock(3, self.blocks[0].filters, downsample=True).to(next(self.parameters()).device)
reset_blocks = [self.to_logit]
for i in range(self.num_blocks):
reset_blocks.append(self.blocks[i])
for bl in reset_blocks:
for m in bl.modules():
if type(m) in {nn.Conv2d, nn.Linear}:
nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')
for p in m.parameters(recurse=True):
p._NEW_BLOCK = True
for p in self.parameters():
if not hasattr(p, '_NEW_BLOCK'):
p.DO_NOT_TRAIN_UNTIL = self.frozen_until_step
class StyleGan2DivergenceLoss(L.ConfigurableLoss):
def __init__(self, opt, env):
super().__init__(opt, env)
self.real = opt['real']
self.fake = opt['fake']
self.discriminator = opt['discriminator']
self.for_gen = opt['gen_loss']
self.gp_frequency = opt['gradient_penalty_frequency']
self.noise = opt['noise'] if 'noise' in opt.keys() else 0
def forward(self, net, state):
real_input = state[self.real]
fake_input = state[self.fake]
if self.noise != 0:
fake_input = fake_input + torch.rand_like(fake_input) * self.noise
real_input = real_input + torch.rand_like(real_input) * self.noise
D = self.env['discriminators'][self.discriminator]
fake = D(fake_input)
if self.for_gen:
return fake.mean()
else:
real_input.requires_grad_() # <-- Needed to compute gradients on the input.
real = D(real_input)
divergence_loss = (F.relu(1 + real) + F.relu(1 - fake)).mean()
# Apply gradient penalty. TODO: migrate this elsewhere.
if self.env['step'] % self.gp_frequency == 0:
gp = gradient_penalty(real_input, real)
self.metrics.append(("gradient_penalty", gp.clone().detach()))
divergence_loss = divergence_loss + gp
real_input.requires_grad_(requires_grad=False)
return divergence_loss
class StyleGan2PathLengthLoss(L.ConfigurableLoss):
def __init__(self, opt, env):
super().__init__(opt, env)
self.w_styles = opt['w_styles']
self.gen = opt['gen']
self.pl_mean = None
self.pl_length_ma = EMA(.99)
def forward(self, net, state):
w_styles = state[self.w_styles]
gen = state[self.gen]
pl_lengths = calc_pl_lengths(w_styles, gen)
avg_pl_length = np.mean(pl_lengths.detach().cpu().numpy())
if not is_empty(self.pl_mean):
pl_loss = ((pl_lengths - self.pl_mean) ** 2).mean()
if not torch.isnan(pl_loss):
return pl_loss
else:
print("Path length loss returned NaN!")
self.pl_mean = self.pl_length_ma.update_average(self.pl_mean, avg_pl_length)
return 0
@register_model
def register_stylegan2_lucidrains(opt_net, opt):
is_structured = opt_net['structured'] if 'structured' in opt_net.keys() else False
attn = opt_net['attn_layers'] if 'attn_layers' in opt_net.keys() else []
return StyleGan2GeneratorWithLatent(image_size=opt_net['image_size'], latent_dim=opt_net['latent_dim'],
style_depth=opt_net['style_depth'], structure_input=is_structured,
attn_layers=attn)
@register_model
def register_stylegan2_discriminator(opt_net, opt):
attn = opt_net['attn_layers'] if 'attn_layers' in opt_net.keys() else []
disc = StyleGan2Discriminator(image_size=opt_net['image_size'], input_filters=opt_net['in_nc'], attn_layers=attn,
do_checkpointing=opt_get(opt_net, ['do_checkpointing'], False),
quantize=opt_get(opt_net, ['quantize'], False),
mlp=opt_get(opt_net, ['mlp_head'], True))
if 'use_partial_pretrained' in opt_net.keys():
disc.configure_partial_training(opt_net['bypass_blocks'], opt_net['partial_training_blocks'], opt_net['intermediate_blocks_frozen_until'])
return StyleGan2Augmentor(disc, opt_net['image_size'], types=opt_net['augmentation_types'], prob=opt_net['augmentation_probability'])