forked from mrq/DL-Art-School
Commit my attempt at "conforming" the lucidrains stylegan implementation to the reference spec. Not working; will probably be abandoned.
parent 209332292a
commit 1708136b55
codes/models/archs/stylegan/stylegan2_lucidrains_conformed.py (new file, 847 lines)

@@ -0,0 +1,847 @@
import math
import multiprocessing
from contextlib import contextmanager, ExitStack
from functools import partial
from math import log2
from random import random

import torch
import torch.nn.functional as F
import trainer.losses as L
import numpy as np

from kornia.filters import filter2D
from linear_attention_transformer import ImageLinearAttention
from torch import nn
from torch.autograd import grad as torch_grad
from vector_quantize_pytorch import VectorQuantize

from utils.util import checkpoint

try:
    from apex import amp

    APEX_AVAILABLE = True
except:
    APEX_AVAILABLE = False

assert torch.cuda.is_available(), 'You need to have an Nvidia GPU with CUDA installed.'

num_cores = multiprocessing.cpu_count()

# constants

EPS = 1e-8
CALC_FID_NUM_IMAGES = 12800


# helper classes

def DiffAugment(x, types=[]):
    for p in types:
        for f in AUGMENT_FNS[p]:
            x = f(x)
    return x.contiguous()

def rand_brightness(x):
    x = x + (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5)
    return x

def rand_saturation(x):
    x_mean = x.mean(dim=1, keepdim=True)
    x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2) + x_mean
    return x

def rand_contrast(x):
    x_mean = x.mean(dim=[1, 2, 3], keepdim=True)
    x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5) + x_mean
    return x

def rand_translation(x, ratio=0.125):
    shift_x, shift_y = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
    translation_x = torch.randint(-shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device)
    translation_y = torch.randint(-shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device)
    grid_batch, grid_x, grid_y = torch.meshgrid(
        torch.arange(x.size(0), dtype=torch.long, device=x.device),
        torch.arange(x.size(2), dtype=torch.long, device=x.device),
        torch.arange(x.size(3), dtype=torch.long, device=x.device),
    )
    grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1)
    grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1)
    x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0])
    x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x, grid_y].permute(0, 3, 1, 2)
    return x

def rand_cutout(x, ratio=0.5):
    cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
    offset_x = torch.randint(0, x.size(2) + (1 - cutout_size[0] % 2), size=[x.size(0), 1, 1], device=x.device)
    offset_y = torch.randint(0, x.size(3) + (1 - cutout_size[1] % 2), size=[x.size(0), 1, 1], device=x.device)
    grid_batch, grid_x, grid_y = torch.meshgrid(
        torch.arange(x.size(0), dtype=torch.long, device=x.device),
        torch.arange(cutout_size[0], dtype=torch.long, device=x.device),
        torch.arange(cutout_size[1], dtype=torch.long, device=x.device),
    )
    grid_x = torch.clamp(grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1)
    grid_y = torch.clamp(grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1)
    mask = torch.ones(x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device)
    mask[grid_batch, grid_x, grid_y] = 0
    x = x * mask.unsqueeze(1)
    return x

AUGMENT_FNS = {
    'color': [rand_brightness, rand_saturation, rand_contrast],
    'translation': [rand_translation],
    'cutout': [rand_cutout],
}

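# Usage sketch (illustrative; the concrete 'types' list comes from config):
# StyleGan2Augmentor below applies these policies to discriminator inputs, e.g.
#   images = DiffAugment(images, types=['color', 'translation', 'cutout'])
# where each entry selects one group from AUGMENT_FNS, applied in order.
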
class NanException(Exception):
    pass


class EMA():
    def __init__(self, beta):
        super().__init__()
        self.beta = beta

    def update_average(self, old, new):
        if not exists(old):
            return new
        return old * self.beta + (1 - self.beta) * new


class Flatten(nn.Module):
    def forward(self, x):
        return x.reshape(x.shape[0], -1)


class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        return self.fn(x) + x


class Rezero(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn
        self.g = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        return self.fn(x) * self.g


class PermuteToFrom(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        x = x.permute(0, 2, 3, 1)
        out, loss = self.fn(x)
        out = out.permute(0, 3, 1, 2)
        return out, loss


class Blur(nn.Module):
    def __init__(self):
        super().__init__()
        f = torch.Tensor([1, 2, 1])
        self.register_buffer('f', f)

    def forward(self, x):
        f = self.f
        f = f[None, None, :] * f[None, :, None]
        return filter2D(x, f, normalized=True)


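# Assumed helper carried over from the upstream lucidrains implementation; it is
# referenced by attn_and_ff and DiscriminatorBlock below but did not appear in this diff.
def leaky_relu(p=0.2):
    return nn.LeakyReLU(p, inplace=True)

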
# one layer of self-attention and feedforward, for images

attn_and_ff = lambda chan: nn.Sequential(*[
    Residual(Rezero(ImageLinearAttention(chan, norm_queries=True))),
    Residual(Rezero(nn.Sequential(nn.Conv2d(chan, chan * 2, 1), leaky_relu(), nn.Conv2d(chan * 2, chan, 1))))
])


# helpers

def exists(val):
    return val is not None


@contextmanager
def null_context():
    yield


def combine_contexts(contexts):
    @contextmanager
    def multi_contexts():
        with ExitStack() as stack:
            yield [stack.enter_context(ctx()) for ctx in contexts]

    return multi_contexts


def default(value, d):
    return value if exists(value) else d


def cycle(iterable):
    while True:
        for i in iterable:
            yield i


def cast_list(el):
    return el if isinstance(el, list) else [el]


def is_empty(t):
    if isinstance(t, torch.Tensor):
        return t.nelement() == 0
    return not exists(t)


def raise_if_nan(t):
    if torch.isnan(t):
        raise NanException


def gradient_accumulate_contexts(gradient_accumulate_every, is_ddp, ddps):
    if is_ddp:
        num_no_syncs = gradient_accumulate_every - 1
        head = [combine_contexts(map(lambda ddp: ddp.no_sync, ddps))] * num_no_syncs
        tail = [null_context]
        contexts = head + tail
    else:
        contexts = [null_context] * gradient_accumulate_every

    for context in contexts:
        with context():
            yield


def loss_backwards(fp16, loss, optimizer, loss_id, **kwargs):
    if fp16:
        with amp.scale_loss(loss, optimizer, loss_id) as scaled_loss:
            scaled_loss.backward(**kwargs)
    else:
        loss.backward(**kwargs)


def gradient_penalty(images, output, weight=10):
    batch_size = images.shape[0]
    gradients = torch_grad(outputs=output, inputs=images,
                           grad_outputs=torch.ones(output.size(), device=images.device),
                           create_graph=True, retain_graph=True, only_inputs=True)[0]

    gradients = gradients.reshape(batch_size, -1)
    return weight * ((gradients.norm(2, dim=1) - 1) ** 2).mean()


def calc_pl_lengths(styles, images):
    device = images.device
    num_pixels = images.shape[2] * images.shape[3]
    pl_noise = torch.randn(images.shape, device=device) / math.sqrt(num_pixels)
    outputs = (images * pl_noise).sum()

    pl_grads = torch_grad(outputs=outputs, inputs=styles,
                          grad_outputs=torch.ones(outputs.shape, device=device),
                          create_graph=True, retain_graph=True, only_inputs=True)[0]

    return (pl_grads ** 2).sum(dim=2).mean(dim=1).sqrt()


def image_noise(n, im_size, device):
    return torch.FloatTensor(n, im_size, im_size, 1).uniform_(0., 1.).cuda(device)


class BiasedLeakyReLU(nn.Module):
    def __init__(self, channel, bias=True, negative_slope=0.2, scale=2 ** 0.5):
        super().__init__()

        if bias:
            self.bias = nn.Parameter(torch.zeros(channel))

        else:
            self.bias = None

        self.negative_slope = negative_slope
        self.scale = scale

    def forward(self, input):
        return biased_leaky_relu(input, self.bias, self.negative_slope, self.scale)


def biased_leaky_relu(input, bias=None, negative_slope=0.2, scale=2 ** 0.5):
    if bias is not None:
        rest_dim = [1] * (input.ndim - bias.ndim - 1)
        return (
            F.leaky_relu(
                input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=negative_slope
            )
            * scale
        )

    else:
        # Use the passed-in slope rather than a hardcoded 0.2 so the argument is honored.
        return F.leaky_relu(input, negative_slope=negative_slope) * scale
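
# Note: biased_leaky_relu adds a learned per-channel bias, applies LeakyReLU, then
# multiplies by scale=2 ** 0.5, which is how the reference implementation's fused
# bias + activation behaves.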


def evaluate_in_chunks(max_batch_size, model, *args):
    split_args = list(zip(*list(map(lambda x: x.split(max_batch_size, dim=0), args))))
    chunked_outputs = [model(*i) for i in split_args]
    if len(chunked_outputs) == 1:
        return chunked_outputs[0]
    return torch.cat(chunked_outputs, dim=0)


def set_requires_grad(model, bool):
    for p in model.parameters():
        p.requires_grad = bool


def slerp(val, low, high):
    low_norm = low / torch.norm(low, dim=1, keepdim=True)
    high_norm = high / torch.norm(high, dim=1, keepdim=True)
    omega = torch.acos((low_norm * high_norm).sum(1))
    so = torch.sin(omega)
    res = (torch.sin((1.0 - val) * omega) / so).unsqueeze(1) * low + (torch.sin(val * omega) / so).unsqueeze(1) * high
    return res

# augmentations

def random_hflip(tensor, prob):
    if prob > random():
        return tensor
    return torch.flip(tensor, dims=(3,))


class StyleGan2Augmentor(nn.Module):
    def __init__(self, D, image_size, types, prob):
        super().__init__()
        self.D = D
        self.prob = prob
        self.types = types

    def forward(self, images, detach=False):
        if random() < self.prob:
            images = random_hflip(images, prob=0.5)
            images = DiffAugment(images, types=self.types)

        if detach:
            images = images.detach()

        # Save away for use elsewhere (e.g. unet loss)
        self.aug_images = images

        return self.D(images)


# stylegan2 classes

class EqualLinear(nn.Module):
    def __init__(self, in_dim, out_dim, lr_mul=1, bias=True, activation=False):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_dim, in_dim))
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_dim))

        self.lr_mul = lr_mul
        self.activation = activation

    def forward(self, input):
        if self.activation:
            # Per the reference spec, when an activation is requested the bias is folded
            # into the fused leaky-relu rather than into the linear op.
            out = F.linear(input, self.weight * self.lr_mul)
            out = biased_leaky_relu(out, self.bias * self.lr_mul)
        else:
            out = F.linear(input, self.weight * self.lr_mul, bias=self.bias * self.lr_mul)
        return out


class StyleVectorizer(nn.Module):
    def __init__(self, emb, depth, lr_mul=0.01):
        super().__init__()

        layers = []
        for i in range(depth):
            layers.extend([EqualLinear(emb, emb, lr_mul, activation=True)])

        self.net = nn.Sequential(*layers)

    def forward(self, x):
        x = F.normalize(x, dim=1)
        return self.net(x)


class RGBBlock(nn.Module):
    def __init__(self, latent_dim, input_channel, upsample, rgba=False):
        super().__init__()
        self.input_channel = input_channel
        self.to_style = nn.Linear(latent_dim, input_channel)

        out_filters = 3 if not rgba else 4
        self.conv = Conv2DMod(input_channel, out_filters, 1, demod=False)

        self.upsample = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            Blur()
        ) if upsample else None

    def forward(self, x, prev_rgb, istyle):
        b, c, h, w = x.shape
        style = self.to_style(istyle)
        x = self.conv(x, style)

        if exists(prev_rgb):
            x = x + prev_rgb

        if exists(self.upsample):
            x = self.upsample(x)

        return x


class AdaptiveInstanceNorm(nn.Module):
    def __init__(self, in_channel, style_dim):
        super().__init__()
        from models.arch_util import ConvGnLelu
        self.style2scale = ConvGnLelu(style_dim, in_channel, kernel_size=1, norm=False, activation=False, bias=True)
        self.style2bias = ConvGnLelu(style_dim, in_channel, kernel_size=1, norm=False, activation=False, bias=True, weight_init_factor=0)
        self.norm = nn.InstanceNorm2d(in_channel)

    def forward(self, input, style):
        gamma = self.style2scale(style)
        beta = self.style2bias(style)
        out = self.norm(input)
        out = gamma * out + beta
        return out


class NoiseInjection(nn.Module):
    def __init__(self, channel):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(1, channel, 1, 1))

    def forward(self, image, noise):
        return image + self.weight * noise


class EqualLR:
    def __init__(self, name):
        self.name = name

    def compute_weight(self, module):
        weight = getattr(module, self.name + '_orig')
        fan_in = weight.data.size(1) * weight.data[0][0].numel()

        return weight * math.sqrt(2 / fan_in)

    @staticmethod
    def apply(module, name):
        fn = EqualLR(name)

        weight = getattr(module, name)
        del module._parameters[name]
        module.register_parameter(name + '_orig', nn.Parameter(weight.data))
        module.register_forward_pre_hook(fn)

        return fn

    def __call__(self, module, input):
        weight = self.compute_weight(module)
        setattr(module, self.name, weight)


def equal_lr(module, name='weight'):
    EqualLR.apply(module, name)

    return module


class EqualConv2d(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__()

        conv = nn.Conv2d(*args, **kwargs)
        conv.weight.data.normal_()
        conv.bias.data.zero_()
        self.conv = equal_lr(conv)

    def forward(self, input):
        return self.conv(input)


class Conv2DMod(nn.Module):
    def __init__(self, in_chan, out_chan, kernel, demod=True, stride=1, dilation=1, **kwargs):
        super().__init__()
        self.filters = out_chan
        self.demod = demod
        self.kernel = kernel
        self.stride = stride
        self.dilation = dilation
        self.weight = nn.Parameter(torch.randn((out_chan, in_chan, kernel, kernel)))
        nn.init.kaiming_normal_(self.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')

    def _get_same_padding(self, size, kernel, dilation, stride):
        return ((size - 1) * (stride - 1) + dilation * (kernel - 1)) // 2

    def forward(self, x, y):
        b, c, h, w = x.shape

        w1 = y[:, None, :, None, None]
        w2 = self.weight[None, :, :, :, :]
        weights = w2 * (w1 + 1)

        if self.demod:
            d = torch.rsqrt((weights ** 2).sum(dim=(2, 3, 4), keepdim=True) + EPS)
            weights = weights * d

        x = x.reshape(1, -1, h, w)

        _, _, *ws = weights.shape
        weights = weights.reshape(b * self.filters, *ws)

        padding = self._get_same_padding(h, self.kernel, self.dilation, self.stride)
        x = F.conv2d(x, weights, padding=padding, groups=b)

        x = x.reshape(-1, self.filters, h, w)
        return x

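# Shape sketch (illustrative sizes only): Conv2DMod realizes StyleGAN2's weight
# modulation/demodulation with a grouped convolution over the batch.
#   conv = Conv2DMod(in_chan=64, out_chan=128, kernel=3)
#   x = torch.randn(4, 64, 16, 16)   # (batch, in_chan, h, w)
#   y = torch.randn(4, 64)           # per-sample style, one scale per input channel
#   out = conv(x, y)                 # -> (4, 128, 16, 16), spatial size preserved
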

class GeneratorBlock(nn.Module):
    def __init__(self, latent_dim, input_channels, filters, upsample=True, upsample_rgb=True, rgba=False, initial_block=False):
        super().__init__()
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False) if upsample else None

        self.to_style1 = nn.Linear(latent_dim, input_channels)
        self.noise1_scale = nn.Parameter(torch.full((1,), fill_value=1e-5))
        self.conv1 = Conv2DMod(input_channels, filters, 3)
        self.activation1 = BiasedLeakyReLU(filters)

        self.initial_block = initial_block
        if not initial_block:
            self.to_style2 = nn.Linear(latent_dim, filters)
            self.noise2_scale = nn.Parameter(torch.full((1,), fill_value=1e-5))
            self.conv2 = Conv2DMod(filters, filters, 3)
            self.activation2 = BiasedLeakyReLU(filters)

        self.to_rgb = RGBBlock(latent_dim, filters, upsample_rgb, rgba)

    def forward(self, x, prev_rgb, istyle, inoise1=None, inoise2=None):
        if exists(self.upsample):
            x = self.upsample(x)
        if inoise1 is None:
            b, c, h, w = x.shape
            inoise1 = torch.randn((b, 1, h, w), device=x.device)
            inoise2 = torch.randn((b, 1, h, w), device=x.device)  # Assume that both are None if one is None.

        noise1 = inoise1 * self.noise1_scale
        style1 = self.to_style1(istyle)
        x = self.conv1(x, style1)
        x = self.activation1(x + noise1)

        if not self.initial_block:
            noise2 = inoise2 * self.noise2_scale
            style2 = self.to_style2(istyle)
            x = self.conv2(x, style2)
            x = self.activation2(x + noise2)

        rgb = self.to_rgb(x, prev_rgb, istyle)
        return x, rgb


class Generator(nn.Module):
    def __init__(self, image_size, latent_dim, network_capacity=16, transparent=False, attn_layers=[], no_const=False,
                 fmap_max=512):
        super().__init__()
        self.image_size = image_size
        self.latent_dim = latent_dim
        self.num_layers = int(log2(image_size) - 1)

        filters = [network_capacity * (2 ** (i + 1)) for i in range(self.num_layers)][::-1]

        set_fmap_max = partial(min, fmap_max)
        filters = list(map(set_fmap_max, filters))
        init_channels = filters[0]
        filters = [init_channels, *filters]

        in_out_pairs = zip(filters[:-1], filters[1:])
        self.no_const = no_const

        if no_const:
            self.to_initial_block = nn.ConvTranspose2d(latent_dim, init_channels, 4, 1, 0, bias=False)
        else:
            self.initial_block = nn.Parameter(torch.randn((1, init_channels, 4, 4)))

        self.blocks = nn.ModuleList([])
        self.attns = nn.ModuleList([])

        for ind, (in_chan, out_chan) in enumerate(in_out_pairs):
            not_first = ind != 0
            not_last = ind != (self.num_layers - 1)
            num_layer = self.num_layers - ind

            attn_fn = attn_and_ff(in_chan) if num_layer in attn_layers else None

            self.attns.append(attn_fn)

            block_fn = GeneratorBlock

            block = block_fn(
                latent_dim,
                in_chan,
                out_chan,
                upsample=not_first,
                upsample_rgb=not_last,
                rgba=transparent,
                initial_block=(ind == 0)
            )
            self.blocks.append(block)

    def forward(self, styles, input_noises):
        batch_size = styles.shape[0]

        if self.no_const:
            avg_style = styles.mean(dim=1)[:, :, None, None]
            x = self.to_initial_block(avg_style)
        else:
            x = self.initial_block.expand(batch_size, -1, -1, -1)

        rgb = None
        styles = styles.transpose(0, 1)

        n = 0
        for style, block, attn in zip(styles, self.blocks, self.attns):
            if exists(attn):
                x = checkpoint(attn, x)
            x, rgb = checkpoint(block, x, rgb, style, input_noises[n], input_noises[n + 1])
            n = 1 if n == 0 else n + 2  # The first block only consumes 1 noise, the rest consume 2.

        return rgb


# Wrapper that combines style vectorizer with the actual generator.
class StyleGan2GeneratorWithLatent(nn.Module):
    def __init__(self, image_size, latent_dim=512, style_depth=8, lr_mlp=.1, network_capacity=16, transparent=False,
                 attn_layers=[], no_const=False, fmap_max=512):
        super().__init__()
        self.vectorizer = StyleVectorizer(latent_dim, style_depth, lr_mul=lr_mlp)
        self.gen = Generator(image_size, latent_dim, network_capacity, transparent, attn_layers, no_const, fmap_max)
        self.mixed_prob = .9
        self._init_weights()


    def noise(self, n, latent_dim, device):
        return torch.randn(n, latent_dim).cuda(device)

    def noise_list(self, n, layers, latent_dim, device):
        return [(self.noise(n, latent_dim, device), layers)]

    def mixed_list(self, n, layers, latent_dim, device):
        tt = int(torch.rand(()).numpy() * layers)
        return self.noise_list(n, tt, latent_dim, device) + self.noise_list(n, layers - tt, latent_dim, device)

    def latent_to_w(self, style_vectorizer, latent_descr):
        return [(style_vectorizer(z), num_layers) for z, num_layers in latent_descr]

    def styles_def_to_tensor(self, styles_def):
        return torch.cat([t[:, None, :].expand(-1, n, -1) for t, n in styles_def], dim=1)

    # If provided, 'noise' should be a list of tensors that is fed into each input block.
    # b=batch_size.
    def forward(self, b, noises=None):
        if noises is None:
            noises = [None] * (len(self.gen.blocks) * 2 - 1)
        full_random_latents = True
        if full_random_latents:
            style = self.noise(b * 2, self.gen.latent_dim, next(self.parameters()).device)
            w = self.vectorizer(style)
            # Randomly distribute styles across layers
            w_styles = w[:, None, :].expand(-1, self.gen.num_layers, -1).clone()
            for j in range(b):
                cutoff = int(torch.rand(()).numpy() * self.gen.num_layers)
                if cutoff == self.gen.num_layers or random() > self.mixed_prob:
                    w_styles[j] = w_styles[j * 2]
                else:
                    w_styles[j, :cutoff] = w_styles[j * 2, :cutoff]
                    w_styles[j, cutoff:] = w_styles[j * 2 + 1, cutoff:]
            w_styles = w_styles[:b]
        else:
            get_latents_fn = self.mixed_list if random() < self.mixed_prob else self.noise_list
            # 'x' is not defined in this branch; use the module's own device instead.
            device = next(self.parameters()).device
            style = get_latents_fn(b, self.gen.num_layers, self.gen.latent_dim, device=device)
            w_space = self.latent_to_w(self.vectorizer, style)
            w_styles = self.styles_def_to_tensor(w_space)

        return self.gen(w_styles, noises), w_styles

    def _init_weights(self):
        for m in self.modules():
            if type(m) in {nn.Conv2d, nn.Linear} and hasattr(m, 'weight'):
                nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')


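# Usage sketch (illustrative, assumed sizes): the wrapper draws its own latents, so
# the caller only supplies a batch size (and optionally a list of per-block noise tensors).
#   g = StyleGan2GeneratorWithLatent(image_size=64, latent_dim=512, style_depth=8)
#   rgb, w_styles = g(4)   # rgb: (4, 3, 64, 64); w_styles: (4, num_layers, 512)
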
class DiscriminatorBlock(nn.Module):
    def __init__(self, input_channels, filters, downsample=True):
        super().__init__()
        self.conv_res = nn.Conv2d(input_channels, filters, 1, stride=(2 if downsample else 1))

        self.net = nn.Sequential(
            nn.Conv2d(input_channels, filters, 3, padding=1),
            leaky_relu(),
            nn.Conv2d(filters, filters, 3, padding=1),
            leaky_relu()
        )

        self.downsample = nn.Sequential(
            Blur(),
            nn.Conv2d(filters, filters, 3, padding=1, stride=2)
        ) if downsample else None

    def forward(self, x):
        res = self.conv_res(x)
        x = self.net(x)
        if exists(self.downsample):
            x = self.downsample(x)
        x = (x + res) * (1 / math.sqrt(2))
        return x


class StyleGan2Discriminator(nn.Module):
    def __init__(self, image_size, network_capacity=16, fq_layers=[], fq_dict_size=256, attn_layers=[],
                 transparent=False, fmap_max=512, input_filters=3):
        super().__init__()
        num_layers = int(log2(image_size) - 1)

        blocks = []
        filters = [input_filters] + [(64) * (2 ** i) for i in range(num_layers + 1)]

        set_fmap_max = partial(min, fmap_max)
        filters = list(map(set_fmap_max, filters))
        chan_in_out = list(zip(filters[:-1], filters[1:]))

        blocks = []
        attn_blocks = []
        quantize_blocks = []

        for ind, (in_chan, out_chan) in enumerate(chan_in_out):
            num_layer = ind + 1
            is_not_last = ind != (len(chan_in_out) - 1)

            block = DiscriminatorBlock(in_chan, out_chan, downsample=is_not_last)
            blocks.append(block)

            attn_fn = attn_and_ff(out_chan) if num_layer in attn_layers else None

            attn_blocks.append(attn_fn)

            quantize_fn = PermuteToFrom(VectorQuantize(out_chan, fq_dict_size)) if num_layer in fq_layers else None
            quantize_blocks.append(quantize_fn)

        self.blocks = nn.ModuleList(blocks)
        self.attn_blocks = nn.ModuleList(attn_blocks)
        self.quantize_blocks = nn.ModuleList(quantize_blocks)

        chan_last = filters[-1]
        latent_dim = 2 * 2 * chan_last

        self.final_conv = nn.Conv2d(chan_last, chan_last, 3, padding=1)
        self.flatten = Flatten()
        self.to_logit = nn.Linear(latent_dim, 1)

        self._init_weights()

    def forward(self, x):
        b, *_ = x.shape

        quantize_loss = torch.zeros(1).to(x)

        for (block, attn_block, q_block) in zip(self.blocks, self.attn_blocks, self.quantize_blocks):
            x = block(x)

            if exists(attn_block):
                x = attn_block(x)

            if exists(q_block):
                x, _, loss = q_block(x)
                quantize_loss += loss

        x = self.final_conv(x)
        x = self.flatten(x)
        x = self.to_logit(x)
        if exists(q_block):
            return x.squeeze(), quantize_loss
        else:
            return x.squeeze()

    def _init_weights(self):
        for m in self.modules():
            if type(m) in {nn.Conv2d, nn.Linear}:
                nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')


class StyleGan2DivergenceLoss(L.ConfigurableLoss):
    def __init__(self, opt, env):
        super().__init__(opt, env)
        self.real = opt['real']
        self.fake = opt['fake']
        self.discriminator = opt['discriminator']
        self.for_gen = opt['gen_loss']
        self.gp_frequency = opt['gradient_penalty_frequency']
        self.noise = opt['noise'] if 'noise' in opt.keys() else 0

    def forward(self, net, state):
        real_input = state[self.real]
        fake_input = state[self.fake]
        if self.noise != 0:
            fake_input = fake_input + torch.rand_like(fake_input) * self.noise
            real_input = real_input + torch.rand_like(real_input) * self.noise

        D = self.env['discriminators'][self.discriminator]
        fake = D(fake_input)
        if self.for_gen:
            return fake.mean()
        else:
            real_input.requires_grad_()  # <-- Needed to compute gradients on the input.
            real = D(real_input)
            divergence_loss = (F.relu(1 + real) + F.relu(1 - fake)).mean()

            # Apply gradient penalty. TODO: migrate this elsewhere.
            if self.env['step'] % self.gp_frequency == 0:
                from models.stylegan.stylegan2_lucidrains import gradient_penalty
                gp = gradient_penalty(real_input, real)
                self.metrics.append(("gradient_penalty", gp.clone().detach()))
                divergence_loss = divergence_loss + gp

            real_input.requires_grad_(requires_grad=False)
            return divergence_loss


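# Example opt dict for StyleGan2DivergenceLoss (sketch; only the option keys come
# from __init__ above, the state/discriminator names are placeholders):
#   opt = {'real': 'hq', 'fake': 'gen_fake', 'discriminator': 'stylegan_d',
#          'gen_loss': False, 'gradient_penalty_frequency': 4, 'noise': 0}
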
class StyleGan2PathLengthLoss(L.ConfigurableLoss):
    def __init__(self, opt, env):
        super().__init__(opt, env)
        self.w_styles = opt['w_styles']
        self.gen = opt['gen']
        self.pl_mean = None
        from models.stylegan.stylegan2_lucidrains import EMA
        self.pl_length_ma = EMA(.99)

    def forward(self, net, state):
        w_styles = state[self.w_styles]
        gen = state[self.gen]
        from models.stylegan.stylegan2_lucidrains import calc_pl_lengths
        pl_lengths = calc_pl_lengths(w_styles, gen)
        avg_pl_length = np.mean(pl_lengths.detach().cpu().numpy())

        from models.stylegan.stylegan2_lucidrains import is_empty
        if not is_empty(self.pl_mean):
            pl_loss = ((pl_lengths - self.pl_mean) ** 2).mean()
            if not torch.isnan(pl_loss):
                return pl_loss
            else:
                print("Path length loss returned NaN!")

        self.pl_mean = self.pl_length_ma.update_average(self.pl_mean, avg_pl_length)
        return 0
@@ -0,0 +1,286 @@
# Converts from Tensorflow Stylegan2 weights to weights used by this model.
# Original source: https://raw.githubusercontent.com/rosinality/stylegan2-pytorch/master/convert_weight.py
# Adapted to lucidrains' Stylegan implementation.
#
# Also doesn't require you to install Tensorflow 1.15 or clone the nVidia repo.

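# Invocation sketch (flags taken from the argparse setup below; the script and
# checkpoint filenames are placeholders):
#   python convert_weight.py --gen path/to/stylegan2_weights.pkl
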
import argparse
import os
import sys
import pickle
import math

import torch
import numpy as np
from torchvision import utils


# Converts from the TF state_dict input provided into the vars originally expected from the rosinality converter.
from models.stylegan.stylegan2_lucidrains import StyleGan2GeneratorWithLatent


def get_vars(vars, source_name):
    net_name = source_name.split('/')[0]
    vars_as_tuple_list = vars[net_name]['variables']
    result_vars = {}
    for t in vars_as_tuple_list:
        result_vars[t[0]] = t[1]
    return result_vars, source_name.replace(net_name + "/", "")

def get_vars_direct(vars, source_name):
    v, n = get_vars(vars, source_name)
    return v[n]


def convert_modconv(vars, source_name, target_name, flip=False, numeral=1):
    vars, source_name = get_vars(vars, source_name)
    weight = vars[source_name + "/weight"]
    mod_weight = vars[source_name + "/mod_weight"]
    mod_bias = vars[source_name + "/mod_bias"]
    noise = vars[source_name + "/noise_strength"]
    bias = vars[source_name + "/bias"]

    dic = {
        f"conv{numeral}.weight": weight.transpose((3, 2, 0, 1)),
        f"to_style{numeral}.weight": mod_weight.transpose((1, 0)),
        f"to_style{numeral}.bias": mod_bias + 1,
        f"noise{numeral}_scale": np.array([noise]),
        f"activation{numeral}.bias": bias,
    }

    dic_torch = {}

    for k, v in dic.items():
        dic_torch[target_name + "." + k] = torch.from_numpy(v)

    if flip:
        dic_torch[target_name + f".conv{numeral}.weight"] = torch.flip(
            dic_torch[target_name + f".conv{numeral}.weight"], [2, 3]
        )

    return dic_torch


def convert_conv(vars, source_name, target_name, bias=True, start=0):
    vars, source_name = get_vars(vars, source_name)
    weight = vars[source_name + "/weight"]

    dic = {"weight": weight.transpose((3, 2, 0, 1))}

    if bias:
        dic["bias"] = vars[source_name + "/bias"]

    dic_torch = {}

    dic_torch[target_name + f".{start}.weight"] = torch.from_numpy(dic["weight"])

    if bias:
        dic_torch[target_name + f".{start + 1}.bias"] = torch.from_numpy(dic["bias"])

    return dic_torch


def convert_torgb(vars, source_name, target_name):
    vars, source_name = get_vars(vars, source_name)
    weight = vars[source_name + "/weight"]
    mod_weight = vars[source_name + "/mod_weight"]
    mod_bias = vars[source_name + "/mod_bias"]
    bias = vars[source_name + "/bias"]

    dic = {
        "conv.weight": weight.transpose((3, 2, 0, 1)),
        "to_style.weight": mod_weight.transpose((1, 0)),
        "to_style.bias": mod_bias + 1,
        # "bias": bias.reshape((1, 3, 1, 1)), TODO: where is this?
    }

    dic_torch = {}

    for k, v in dic.items():
        dic_torch[target_name + "." + k] = torch.from_numpy(v)

    return dic_torch


def convert_dense(vars, source_name, target_name):
    vars, source_name = get_vars(vars, source_name)
    weight = vars[source_name + "/weight"]
    bias = vars[source_name + "/bias"]

    dic = {"weight": weight.transpose((1, 0)), "bias": bias}

    dic_torch = {}

    for k, v in dic.items():
        dic_torch[target_name + "." + k] = torch.from_numpy(v)

    return dic_torch


def update(state_dict, new, strict=True):

    for k, v in new.items():
        if strict:
            if k not in state_dict:
                raise KeyError(k + " is not found")

            if v.shape != state_dict[k].shape:
                raise ValueError(f"Shape mismatch: {v.shape} vs {state_dict[k].shape}")

        state_dict[k] = v


def discriminator_fill_statedict(statedict, vars, size):
    log_size = int(math.log(size, 2))

    update(statedict, convert_conv(vars, f"{size}x{size}/FromRGB", "convs.0"))

    conv_i = 1

    for i in range(log_size - 2, 0, -1):
        reso = 4 * 2 ** i
        update(
            statedict,
            convert_conv(vars, f"{reso}x{reso}/Conv0", f"convs.{conv_i}.conv1"),
        )
        update(
            statedict,
            convert_conv(
                vars, f"{reso}x{reso}/Conv1_down", f"convs.{conv_i}.conv2", start=1
            ),
        )
        update(
            statedict,
            convert_conv(
                vars, f"{reso}x{reso}/Skip", f"convs.{conv_i}.skip", start=1, bias=False
            ),
        )
        conv_i += 1

    update(statedict, convert_conv(vars, f"4x4/Conv", "final_conv"))
    update(statedict, convert_dense(vars, f"4x4/Dense0", "final_linear.0"))
    update(statedict, convert_dense(vars, f"Output", "final_linear.1"))

    return statedict


def fill_statedict(state_dict, vars, size):
    log_size = int(math.log(size, 2))

    for i in range(8):
        update(state_dict, convert_dense(vars, f"G_mapping/Dense{i}", f"vectorizer.net.{i}"))

    update(
        state_dict,
        {
            "gen.initial_block": torch.from_numpy(
                get_vars_direct(vars, "G_synthesis/4x4/Const/const")
            )
        },
    )

    for i in range(log_size - 1):
        reso = 4 * 2 ** i
        update(
            state_dict,
            convert_torgb(vars, f"G_synthesis/{reso}x{reso}/ToRGB", f"gen.blocks.{i}.to_rgb"),
        )

    update(state_dict, convert_modconv(vars, "G_synthesis/4x4/Conv", "gen.blocks.0", numeral=1))

    for i in range(1, log_size - 1):
        reso = 4 * 2 ** i
        update(
            state_dict,
            convert_modconv(
                vars,
                f"G_synthesis/{reso}x{reso}/Conv0_up",
                f"gen.blocks.{i}",
                # flip=True, # TODO: why??
                numeral=1
            ),
        )
        update(
            state_dict,
            convert_modconv(
                vars, f"G_synthesis/{reso}x{reso}/Conv1", f"gen.blocks.{i}", numeral=2
            ),
        )

    '''
    TODO: consider porting this, though I dont think it is necessary.
    for i in range(0, (log_size - 2) * 2 + 1):
        update(
            state_dict,
            {
                f"noises.noise_{i}": torch.from_numpy(
                    get_vars_direct(vars, f"G_synthesis/noise{i}")
                )
            },
        )
    '''

    return state_dict


if __name__ == "__main__":
    device = "cuda"

    parser = argparse.ArgumentParser(
        description="Tensorflow to pytorch model checkpoint converter"
    )
    parser.add_argument(
        "--gen", action="store_true", help="convert the generator weights"
    )
    parser.add_argument(
        "--channel_multiplier",
        type=int,
        default=2,
        help="channel multiplier factor. config-f = 2, else = 1",
    )
    parser.add_argument("path", metavar="PATH", help="path to the tensorflow weights")

    args = parser.parse_args()
    sys.path.append('scripts\\stylegan2')

    import dnnlib
    from dnnlib.tflib.network import generator, gen_ema

    with open(args.path, "rb") as f:
        pickle.load(f)

    # Weight names are ordered by size. The last name will be something like '1024x1024/<blah>'. We just need to grab that first number.
    size = int(generator['G_synthesis']['variables'][-1][0].split('x')[0])

    g = StyleGan2GeneratorWithLatent(image_size=size, latent_dim=512, style_depth=8)
    state_dict = g.state_dict()
    state_dict = fill_statedict(state_dict, gen_ema, size)

    g.load_state_dict(state_dict, strict=True)

    latent_avg = torch.from_numpy(get_vars_direct(gen_ema, "G/dlatent_avg"))

    ckpt = {"g_ema": state_dict, "latent_avg": latent_avg}

    if args.gen:
        g_train = Generator(size, 512, 8, channel_multiplier=args.channel_multiplier)
        g_train_state = g_train.state_dict()
        g_train_state = fill_statedict(g_train_state, generator, size)
        ckpt["g"] = g_train_state

    name = os.path.splitext(os.path.basename(args.path))[0]
    torch.save(ckpt, name + ".pt")

    batch_size = {256: 16, 512: 9, 1024: 4}
    n_sample = batch_size.get(size, 25)

    g = g.to(device)

    z = np.random.RandomState(5).randn(n_sample, 512).astype("float32")

    with torch.no_grad():
        img_pt, _ = g(8)

    utils.save_image(
        img_pt, name + ".png", nrow=n_sample, normalize=True, range=(-1, 1)
    )