Rosinality stylegan2 port

This commit is contained in:
James Betker 2020-12-17 14:18:46 -07:00
parent 12cf052889
commit e838c6e75b
16 changed files with 1357 additions and 259 deletions

View File

@ -9,7 +9,7 @@ from torchvision import transforms
import torch.nn as nn import torch.nn as nn
from pathlib import Path from pathlib import Path
import models.archs.stylegan.stylegan2 as sg2 import models.archs.stylegan.stylegan2_lucidrains as sg2
def convert_transparent_to_rgb(image): def convert_transparent_to_rgb(image):

View File

@ -1,5 +1,4 @@
import models.archs.stylegan.stylegan2 as stylegan2 import models.archs.stylegan.stylegan2_lucidrains as stylegan2
import models.archs.stylegan.stylegan2_unet_disc as stylegan2_unet
def create_stylegan2_loss(opt_loss, env): def create_stylegan2_loss(opt_loss, env):
@ -8,7 +7,5 @@ def create_stylegan2_loss(opt_loss, env):
return stylegan2.StyleGan2DivergenceLoss(opt_loss, env) return stylegan2.StyleGan2DivergenceLoss(opt_loss, env)
elif type == 'stylegan2_pathlen': elif type == 'stylegan2_pathlen':
return stylegan2.StyleGan2PathLengthLoss(opt_loss, env) return stylegan2.StyleGan2PathLengthLoss(opt_loss, env)
elif type == 'stylegan2_unet_divergence':
return stylegan2_unet.StyleGan2UnetDivergenceLoss(opt_loss, env)
else: else:
raise NotImplementedError raise NotImplementedError

View File

@ -0,0 +1,2 @@
from .fused_act import FusedLeakyReLU, fused_leaky_relu
from .upfirdn2d import upfirdn2d

View File

@ -0,0 +1,110 @@
import os
import torch
from torch import nn
from torch.nn import functional as F
from torch.autograd import Function
from torch.utils.cpp_extension import load
import fused_bias_act_cuda
class FusedLeakyReLUFunctionBackward(Function):
    """Autograd function that evaluates the gradient of the fused bias/activation op.

    It is a ``Function`` of its own, with its own ``backward``, so the
    gradient computation can itself be differentiated.
    """

    @staticmethod
    def forward(ctx, grad_output, out, bias, negative_slope, scale):
        """Return ``(grad_input, grad_bias)``.

        ``bias`` is only tested for truthiness; when it is falsy an empty
        tensor is returned in place of the bias gradient.
        """
        ctx.negative_slope = negative_slope
        ctx.scale = scale
        ctx.save_for_backward(out)

        placeholder = grad_output.new_empty(0)
        grad_input = fused_bias_act_cuda.fused_bias_act(
            grad_output, placeholder, out, 3, 1, negative_slope, scale
        )

        if not bias:
            return grad_input, placeholder

        # Sum over every axis except axis 1.
        reduce_axes = [0] + list(range(2, grad_input.ndim))
        return grad_input, grad_input.sum(reduce_axes).detach()

    @staticmethod
    def backward(ctx, gradgrad_input, gradgrad_bias):
        """Run the same kernel on the incoming second-order gradients."""
        (out,) = ctx.saved_tensors
        gradgrad_out = fused_bias_act_cuda.fused_bias_act(
            gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale
        )
        # One gradient slot per forward() argument after ctx.
        return gradgrad_out, None, None, None, None
class FusedLeakyReLUFunction(Function):
    """Autograd wrapper around the ``fused_bias_act_cuda.fused_bias_act`` kernel."""

    @staticmethod
    def forward(ctx, input, bias, negative_slope, scale):
        """Apply the fused op; an empty tensor stands in for a missing bias."""
        has_bias = bias is not None
        ctx.bias = has_bias
        ctx.negative_slope = negative_slope
        ctx.scale = scale

        placeholder = input.new_empty(0)
        out = fused_bias_act_cuda.fused_bias_act(
            input,
            bias if has_bias else placeholder,
            placeholder,
            3,
            0,
            negative_slope,
            scale,
        )
        # The backward pass is computed from the output.
        ctx.save_for_backward(out)
        return out

    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients for ``(input, bias, negative_slope, scale)``."""
        (out,) = ctx.saved_tensors
        grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(
            grad_output, out, ctx.bias, ctx.negative_slope, ctx.scale
        )
        return grad_input, (grad_bias if ctx.bias else None), None, None
class FusedLeakyReLU(nn.Module):
    """Module form of :func:`fused_leaky_relu` with an optional learnable bias.

    Args:
        channel: Length of the bias vector.
        bias: When true, a zero-initialised bias parameter is created.
        negative_slope: Slope forwarded to ``fused_leaky_relu``.
        scale: Output multiplier forwarded to ``fused_leaky_relu``.
    """

    def __init__(self, channel, bias=True, negative_slope=0.2, scale=2 ** 0.5):
        super().__init__()
        self.bias = nn.Parameter(torch.zeros(channel)) if bias else None
        self.negative_slope = negative_slope
        self.scale = scale

    def forward(self, input):
        return fused_leaky_relu(
            input,
            bias=self.bias,
            negative_slope=self.negative_slope,
            scale=self.scale,
        )
def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2 ** 0.5):
    """Add an optional per-channel bias, apply leaky ReLU, then multiply by ``scale``.

    CPU tensors are handled with plain PyTorch ops; tensors on any other
    device are dispatched to ``FusedLeakyReLUFunction``.

    Args:
        input: Input tensor; axis 1 is the axis the bias is broadcast along.
        bias: Optional 1-D tensor added along axis 1, or ``None``.
        negative_slope: Slope used for negative inputs.
        scale: Multiplier applied to the activated output.

    Returns:
        Tensor with the same shape as ``input``.
    """
    if input.device.type != "cpu":
        return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)

    if bias is not None:
        # Reshape bias to (1, C, 1, ..., 1) so it broadcasts along axis 1.
        rest_dim = [1] * (input.ndim - bias.ndim - 1)
        input = input + bias.view(1, bias.shape[0], *rest_dim)

    # Honour the caller's negative_slope (was hard-coded to 0.2 on this path).
    return F.leaky_relu(input, negative_slope=negative_slope) * scale

View File

@ -0,0 +1,21 @@
#include <torch/extension.h>
// Forward declaration of the operator implementation; it is not defined in
// this file (presumably in the accompanying CUDA kernel source — confirm).
torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
int act, int grad, float alpha, float scale);
// Argument-validation helpers built on TORCH_CHECK.
// NOTE(review): CHECK_CONTIGUOUS and CHECK_INPUT are defined but not used below.
#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
// Python-facing entry point: checks that `input` and `bias` are CUDA tensors
// and forwards every argument unchanged to fused_bias_act_op.
// `refer` is passed through without any check.
torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
int act, int grad, float alpha, float scale) {
CHECK_CUDA(input);
CHECK_CUDA(bias);
return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale);
}
// Register the wrapper as `fused_bias_act` on the extension module.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)");
}

View File

@ -0,0 +1,32 @@
#!/usr/bin/env python3
import os
import torch
from setuptools import setup, find_packages
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
# Host-compiler flags shared by both extensions.
cxx_args = ['-std=c++11']

# Compute capabilities to generate device code for.
_GENCODE_ARCHS = ('50', '52', '60', '61', '70')

# nvcc flags: one "-gencode arch=...,code=..." pair per capability above.
nvcc_args = [
    flag
    for arch in _GENCODE_ARCHS
    for flag in ('-gencode', 'arch=compute_{0},code=sm_{0}'.format(arch))
]


def _cuda_extension(name, sources):
    """Build a CUDAExtension that uses the shared compiler flags."""
    return CUDAExtension(
        name,
        sources,
        extra_compile_args={'cxx': cxx_args, 'nvcc': nvcc_args},
    )


setup(
    name='stylegan2_ops_cuda',
    ext_modules=[
        _cuda_extension(
            'fused_bias_act_cuda',
            ['fused_bias_act.cpp', 'fused_bias_act_kernel.cu'],
        ),
        _cuda_extension(
            'upfirdn2d_cuda',
            ['upfirdn2d.cpp', 'upfirdn2d_kernel.cu'],
        ),
    ],
    cmdclass={'build_ext': BuildExtension},
)

View File

@ -0,0 +1,23 @@
#include <torch/extension.h>
// Forward declaration of the operator implementation; it is not defined in
// this file (presumably in the accompanying CUDA kernel source — confirm).
torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,
int up_x, int up_y, int down_x, int down_y,
int pad_x0, int pad_x1, int pad_y0, int pad_y1);
// Argument-validation helpers built on TORCH_CHECK.
// NOTE(review): CHECK_CONTIGUOUS and CHECK_INPUT are defined but not used below.
#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
// Python-facing entry point: checks that `input` and `kernel` are CUDA tensors
// and forwards every argument unchanged to upfirdn2d_op.
torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel,
int up_x, int up_y, int down_x, int down_y,
int pad_x0, int pad_x1, int pad_y0, int pad_y1) {
CHECK_CUDA(input);
CHECK_CUDA(kernel);
return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1);
}
// Register the wrapper as `upfirdn2d` on the extension module.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
}

View File

@ -0,0 +1,191 @@
import os
import torch
from torch.nn import functional as F
from torch.autograd import Function
from torch.utils.cpp_extension import load
import upfirdn2d_cuda
class UpFirDn2dBackward(Function):
    """Autograd function that evaluates the gradient of :class:`UpFirDn2d`.

    It has its own ``backward`` so the gradient computation can itself be
    differentiated.
    """

    @staticmethod
    def forward(
        ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size
    ):
        """Map ``grad_output`` back to a tensor of shape ``in_size``."""
        up_x, up_y = up
        down_x, down_y = down
        pad_x0, pad_x1, pad_y0, pad_y1 = pad
        g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad

        # Keep everything backward() needs to re-run the forward filter.
        ctx.save_for_backward(kernel)
        ctx.up_x, ctx.up_y = up_x, up_y
        ctx.down_x, ctx.down_y = down_x, down_y
        ctx.pad_x0, ctx.pad_x1 = pad_x0, pad_x1
        ctx.pad_y0, ctx.pad_y1 = pad_y0, pad_y1
        ctx.in_size = in_size
        ctx.out_size = out_size

        flat_grad = grad_output.reshape(-1, out_size[0], out_size[1], 1)
        # The up and down factors are passed in swapped positions here,
        # together with grad_kernel and the gradient padding.
        grad_input = upfirdn2d_cuda.upfirdn2d(
            flat_grad,
            grad_kernel,
            down_x,
            down_y,
            up_x,
            up_y,
            g_pad_x0,
            g_pad_x1,
            g_pad_y0,
            g_pad_y1,
        )
        return grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3])

    @staticmethod
    def backward(ctx, gradgrad_input):
        """Apply the original forward filter to the second-order gradient."""
        (kernel,) = ctx.saved_tensors
        in_size, out_size = ctx.in_size, ctx.out_size

        flat = gradgrad_input.reshape(-1, in_size[2], in_size[3], 1)
        gradgrad_out = upfirdn2d_cuda.upfirdn2d(
            flat,
            kernel,
            ctx.up_x,
            ctx.up_y,
            ctx.down_x,
            ctx.down_y,
            ctx.pad_x0,
            ctx.pad_x1,
            ctx.pad_y0,
            ctx.pad_y1,
        )
        gradgrad_out = gradgrad_out.view(
            in_size[0], in_size[1], out_size[0], out_size[1]
        )
        # One gradient slot per forward() argument after ctx.
        return gradgrad_out, None, None, None, None, None, None, None, None
class UpFirDn2d(Function):
    """Autograd wrapper around the ``upfirdn2d_cuda.upfirdn2d`` kernel."""

    @staticmethod
    def forward(ctx, input, kernel, up, down, pad):
        """Filter a ``(batch, channel, H, W)`` input with a 2-D ``kernel``.

        ``up`` and ``down`` are ``(x, y)`` factor pairs; ``pad`` is
        ``(x0, x1, y0, y1)``.
        """
        up_x, up_y = up
        down_x, down_y = down
        pad_x0, pad_x1, pad_y0, pad_y1 = pad
        kernel_h, kernel_w = kernel.shape
        batch, channel, in_h, in_w = input.shape

        out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
        out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1

        # State consumed by backward().
        ctx.in_size = input.shape
        ctx.out_size = (out_h, out_w)
        ctx.up = (up_x, up_y)
        ctx.down = (down_x, down_y)
        ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)
        # Padding for the gradient pass, in (x0, x1, y0, y1) order.
        ctx.g_pad = (
            kernel_w - pad_x0 - 1,
            in_w * up_x - out_w * down_x + pad_x0 - up_x + 1,
            kernel_h - pad_y0 - 1,
            in_h * up_y - out_h * down_y + pad_y0 - up_y + 1,
        )
        # The flipped kernel is the one used for the gradient pass.
        ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))

        flat = input.reshape(-1, in_h, in_w, 1)
        out = upfirdn2d_cuda.upfirdn2d(
            flat, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
        )
        return out.view(-1, channel, out_h, out_w)

    @staticmethod
    def backward(ctx, grad_output):
        """Return the input gradient; the other arguments get ``None``."""
        kernel, grad_kernel = ctx.saved_tensors
        grad_input = UpFirDn2dBackward.apply(
            grad_output,
            kernel,
            grad_kernel,
            ctx.up,
            ctx.down,
            ctx.pad,
            ctx.g_pad,
            ctx.in_size,
            ctx.out_size,
        )
        return grad_input, None, None, None, None
def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
    """Run the up/filter/down op with the same factors and padding on both axes.

    CPU inputs go through ``upfirdn2d_native``; all other devices use the
    ``UpFirDn2d`` autograd function.

    Args:
        input: 4-D tensor to filter.
        kernel: 2-D filter tensor.
        up: Upsampling factor used for x and y.
        down: Downsampling factor used for x and y.
        pad: ``(pad0, pad1)`` pair used for x and y.
    """
    pad0, pad1 = pad[0], pad[1]
    if input.device.type != "cpu":
        return UpFirDn2d.apply(
            input, kernel, (up, up), (down, down), (pad0, pad1, pad0, pad1)
        )
    return upfirdn2d_native(input, kernel, up, up, down, down, pad0, pad1, pad0, pad1)
def upfirdn2d_native(
    input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
):
    """Pure-PyTorch version of the op: zero-stuff, pad/crop, filter, subsample.

    Args:
        input: Tensor of shape ``(batch, channel, H, W)``.
        kernel: 2-D filter of shape ``(kernel_h, kernel_w)``.
        up_x, up_y: Upsampling factors.
        down_x, down_y: Downsampling strides.
        pad_x0, pad_x1, pad_y0, pad_y1: Padding per side; negative values crop.

    Returns:
        Tensor of shape ``(batch, channel, out_h, out_w)``.
    """
    _, channel, in_h, in_w = input.shape
    kernel_h, kernel_w = kernel.shape

    # Fold batch and channel together so each plane is filtered on its own.
    planes = input.reshape(-1, in_h, in_w, 1)
    minor = planes.shape[3]

    # Insert (up - 1) zeros after every sample along each spatial axis.
    x = planes.view(-1, in_h, 1, in_w, 1, minor)
    x = F.pad(x, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
    up_h = in_h * up_y
    up_w = in_w * up_x
    x = x.view(-1, up_h, up_w, minor)

    # Positive padding is added first; negative padding is cropped afterwards.
    x = F.pad(
        x, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
    )
    crop_top, crop_bottom = max(-pad_y0, 0), max(-pad_y1, 0)
    crop_left, crop_right = max(-pad_x0, 0), max(-pad_x1, 0)
    x = x[
        :,
        crop_top : x.shape[1] - crop_bottom,
        crop_left : x.shape[2] - crop_right,
        :,
    ]

    padded_h = up_h + pad_y0 + pad_y1
    padded_w = up_w + pad_x0 + pad_x1
    x = x.permute(0, 3, 1, 2).reshape([-1, 1, padded_h, padded_w])

    # The kernel is flipped on both axes before being passed to F.conv2d.
    flipped = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
    x = F.conv2d(x, flipped)
    x = x.reshape(-1, minor, padded_h - kernel_h + 1, padded_w - kernel_w + 1)

    # Keep every down-th sample along each spatial axis.
    x = x.permute(0, 2, 3, 1)[:, ::down_y, ::down_x, :]

    out_h = (padded_h - kernel_h) // down_y + 1
    out_w = (padded_w - kernel_w) // down_x + 1
    return x.view(-1, channel, out_h, out_w)

View File

@ -858,7 +858,7 @@ class StyleGan2DivergenceLoss(L.ConfigurableLoss):
# Apply gradient penalty. TODO: migrate this elsewhere. # Apply gradient penalty. TODO: migrate this elsewhere.
if self.env['step'] % self.gp_frequency == 0: if self.env['step'] % self.gp_frequency == 0:
from models.archs.stylegan.stylegan2 import gradient_penalty from models.archs.stylegan.stylegan2_lucidrains import gradient_penalty
gp = gradient_penalty(real_input, real) gp = gradient_penalty(real_input, real)
self.metrics.append(("gradient_penalty", gp.clone().detach())) self.metrics.append(("gradient_penalty", gp.clone().detach()))
divergence_loss = divergence_loss + gp divergence_loss = divergence_loss + gp
@ -873,17 +873,17 @@ class StyleGan2PathLengthLoss(L.ConfigurableLoss):
self.w_styles = opt['w_styles'] self.w_styles = opt['w_styles']
self.gen = opt['gen'] self.gen = opt['gen']
self.pl_mean = None self.pl_mean = None
from models.archs.stylegan.stylegan2 import EMA from models.archs.stylegan.stylegan2_lucidrains import EMA
self.pl_length_ma = EMA(.99) self.pl_length_ma = EMA(.99)
def forward(self, net, state): def forward(self, net, state):
w_styles = state[self.w_styles] w_styles = state[self.w_styles]
gen = state[self.gen] gen = state[self.gen]
from models.archs.stylegan.stylegan2 import calc_pl_lengths from models.archs.stylegan.stylegan2_lucidrains import calc_pl_lengths
pl_lengths = calc_pl_lengths(w_styles, gen) pl_lengths = calc_pl_lengths(w_styles, gen)
avg_pl_length = np.mean(pl_lengths.detach().cpu().numpy()) avg_pl_length = np.mean(pl_lengths.detach().cpu().numpy())
from models.archs.stylegan.stylegan2 import is_empty from models.archs.stylegan.stylegan2_lucidrains import is_empty
if not is_empty(self.pl_mean): if not is_empty(self.pl_mean):
pl_loss = ((pl_lengths - self.pl_mean) ** 2).mean() pl_loss = ((pl_lengths - self.pl_mean) ** 2).mean()
if not torch.isnan(pl_loss): if not torch.isnan(pl_loss):

View File

@ -0,0 +1,660 @@
import math
import random
import functools
import operator
import torch
from torch import nn
from torch.nn import functional as F
from torch.autograd import Function
from models.archs.stylegan.op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d
class PixelNorm(nn.Module):
    """Scale the input so its mean square along dim 1 is (approximately) one."""

    def __init__(self):
        super().__init__()

    def forward(self, input):
        mean_square = torch.mean(input ** 2, dim=1, keepdim=True)
        # The 1e-8 term keeps rsqrt finite for all-zero inputs.
        return input * torch.rsqrt(mean_square + 1e-8)
def make_kernel(k):
    """Build a float32 filter tensor from ``k``, normalised to sum to one.

    A 1-D ``k`` is expanded to 2-D by multiplying it with itself as a row
    and as a column; other ranks are only normalised.
    """
    kernel = torch.tensor(k, dtype=torch.float32)
    if kernel.ndim == 1:
        kernel = kernel.unsqueeze(0) * kernel.unsqueeze(1)
    kernel.div_(kernel.sum())
    return kernel
class Upsample(nn.Module):
    """Upsample by ``factor`` through ``upfirdn2d`` with a fixed filter."""

    def __init__(self, kernel, factor=2):
        super().__init__()
        self.factor = factor

        # The normalised kernel is multiplied by factor ** 2 before storing.
        scaled = make_kernel(kernel) * (factor ** 2)
        self.register_buffer("kernel", scaled)

        excess = scaled.shape[0] - factor
        self.pad = ((excess + 1) // 2 + factor - 1, excess // 2)

    def forward(self, input):
        return upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)
class Downsample(nn.Module):
    """Downsample by ``factor`` through ``upfirdn2d`` with a fixed filter."""

    def __init__(self, kernel, factor=2):
        super().__init__()
        self.factor = factor

        taps = make_kernel(kernel)
        self.register_buffer("kernel", taps)

        excess = taps.shape[0] - factor
        self.pad = ((excess + 1) // 2, excess // 2)

    def forward(self, input):
        return upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad)
class Blur(nn.Module):
    """Filter the input with a fixed kernel via ``upfirdn2d`` (no resampling).

    Args:
        kernel: Filter taps handed to ``make_kernel``.
        pad: ``(pad0, pad1)`` padding handed to ``upfirdn2d``.
        upsample_factor: When greater than one, the kernel is multiplied by
            ``upsample_factor ** 2``.
    """

    def __init__(self, kernel, pad, upsample_factor=1):
        super().__init__()
        taps = make_kernel(kernel)
        if upsample_factor > 1:
            taps = taps * (upsample_factor ** 2)
        self.register_buffer("kernel", taps)
        self.pad = pad

    def forward(self, input):
        return upfirdn2d(input, self.kernel, pad=self.pad)
class EqualConv2d(nn.Module):
    """2-D convolution whose weight is rescaled on every forward pass.

    The weight is drawn from a standard normal and multiplied by
    ``1 / sqrt(in_channel * kernel_size ** 2)`` each time it is used.
    """

    def __init__(
        self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True
    ):
        super().__init__()
        weight_shape = (out_channel, in_channel, kernel_size, kernel_size)
        self.weight = nn.Parameter(torch.randn(*weight_shape))
        self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)
        self.stride = stride
        self.padding = padding
        self.bias = nn.Parameter(torch.zeros(out_channel)) if bias else None

    def forward(self, input):
        scaled_weight = self.weight * self.scale
        return F.conv2d(
            input,
            scaled_weight,
            bias=self.bias,
            stride=self.stride,
            padding=self.padding,
        )

    def __repr__(self):
        out_ch, in_ch, k_size = self.weight.shape[:3]
        return (
            f"{type(self).__name__}({in_ch}, {out_ch},"
            f" {k_size}, stride={self.stride}, padding={self.padding})"
        )
class EqualLinear(nn.Module):
    """Linear layer whose weight and bias are rescaled on every forward pass.

    Args:
        in_dim: Input feature count.
        out_dim: Output feature count.
        bias: When true, a bias parameter filled with ``bias_init`` is created.
        bias_init: Initial value of every bias element.
        lr_mul: The weight is initialised divided by this value, and both the
            weight scale and the bias are multiplied by it at run time.
        activation: Any truthy value routes the output through
            ``fused_leaky_relu``; falsy gives a plain linear layer.
    """

    def __init__(
        self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None
    ):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
        else:
            self.bias = None
        self.activation = activation
        self.scale = (1 / math.sqrt(in_dim)) * lr_mul
        self.lr_mul = lr_mul

    def forward(self, input):
        weight = self.weight * self.scale
        # A layer built with bias=False has no bias to rescale; the previous
        # unconditional `self.bias * self.lr_mul` raised TypeError in that case.
        bias = None if self.bias is None else self.bias * self.lr_mul

        if self.activation:
            out = F.linear(input, weight)
            return fused_leaky_relu(out, bias)
        return F.linear(input, weight, bias=bias)

    def __repr__(self):
        return (
            f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})"
        )
class ModulatedConv2d(nn.Module):
    """Convolution whose weight is scaled per sample by a style vector.

    The style is mapped to one factor per input channel, multiplied into the
    weight, and (optionally) the result is renormalised per output channel.
    All samples are then processed in one grouped convolution, one group per
    sample. With ``upsample`` a stride-2 transposed convolution followed by a
    blur is used; with ``downsample`` a blur followed by a stride-2
    convolution.
    """

    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size,
        style_dim,
        demodulate=True,
        upsample=False,
        downsample=False,
        blur_kernel=[1, 3, 3, 1],
    ):
        super().__init__()

        self.eps = 1e-8
        self.kernel_size = kernel_size
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.upsample = upsample
        self.downsample = downsample

        taps = len(blur_kernel)
        if upsample:
            factor = 2
            p = (taps - factor) - (kernel_size - 1)
            blur_pad = ((p + 1) // 2 + factor - 1, p // 2 + 1)
            self.blur = Blur(blur_kernel, pad=blur_pad, upsample_factor=factor)

        if downsample:
            factor = 2
            p = (taps - factor) + (kernel_size - 1)
            blur_pad = ((p + 1) // 2, p // 2)
            self.blur = Blur(blur_kernel, pad=blur_pad)

        self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)
        self.padding = kernel_size // 2

        self.weight = nn.Parameter(
            torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)
        )
        self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)
        self.demodulate = demodulate

    def __repr__(self):
        return (
            f"{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, "
            f"upsample={self.upsample}, downsample={self.downsample})"
        )

    def forward(self, input, style):
        batch, in_channel, height, width = input.shape
        k = self.kernel_size
        out_channel = self.out_channel

        # Per-sample, per-input-channel factors: (batch, 1, in_channel, 1, 1).
        style = self.modulation(style).view(batch, 1, in_channel, 1, 1)
        weight = self.scale * self.weight * style

        if self.demodulate:
            demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
            weight = weight * demod.view(batch, out_channel, 1, 1, 1)

        weight = weight.view(batch * out_channel, in_channel, k, k)

        if self.upsample:
            grouped = input.view(1, batch * in_channel, height, width)
            # conv_transpose2d takes (in, out/groups, kH, kW), so swap the
            # channel axes within each sample's weight.
            weight = weight.view(batch, out_channel, in_channel, k, k)
            weight = weight.transpose(1, 2).reshape(
                batch * in_channel, out_channel, k, k
            )
            out = F.conv_transpose2d(
                grouped, weight, padding=0, stride=2, groups=batch
            )
            out = out.view(batch, out_channel, out.shape[2], out.shape[3])
            return self.blur(out)

        if self.downsample:
            blurred = self.blur(input)
            grouped = blurred.view(
                1, batch * in_channel, blurred.shape[2], blurred.shape[3]
            )
            out = F.conv2d(grouped, weight, padding=0, stride=2, groups=batch)
        else:
            grouped = input.view(1, batch * in_channel, height, width)
            out = F.conv2d(grouped, weight, padding=self.padding, groups=batch)

        return out.view(batch, out_channel, out.shape[2], out.shape[3])
class NoiseInjection(nn.Module):
    """Add noise to an image, multiplied by a single learnable weight.

    The weight starts at zero. When no noise tensor is supplied, fresh
    standard-normal noise of shape ``(batch, 1, H, W)`` is drawn.
    """

    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(1))

    def forward(self, image, noise=None):
        if noise is None:
            n, _, h, w = image.shape
            noise = image.new_empty(n, 1, h, w).normal_()
        return image + self.weight * noise
class ConstantInput(nn.Module):
    """Learned ``(1, channel, size, size)`` tensor tiled along the batch axis."""

    def __init__(self, channel, size=4):
        super().__init__()
        self.input = nn.Parameter(torch.randn(1, channel, size, size))

    def forward(self, input):
        # Only the batch size of the argument is used.
        return self.input.repeat(input.shape[0], 1, 1, 1)
class StyledConv(nn.Module):
    """Modulated convolution, then noise injection, then fused leaky ReLU."""

    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size,
        style_dim,
        upsample=False,
        blur_kernel=[1, 3, 3, 1],
        demodulate=True,
    ):
        super().__init__()

        self.conv = ModulatedConv2d(
            in_channel,
            out_channel,
            kernel_size,
            style_dim,
            upsample=upsample,
            blur_kernel=blur_kernel,
            demodulate=demodulate,
        )
        self.noise = NoiseInjection()
        # FusedLeakyReLU is built with its default bias, so no separate bias
        # parameter is kept on this module.
        self.activate = FusedLeakyReLU(out_channel)

    def forward(self, input, style, noise=None):
        convolved = self.conv(input, style)
        noisy = self.noise(convolved, noise=noise)
        return self.activate(noisy)
class ToRGB(nn.Module):
    """Project features to 3 channels with a 1x1 modulated conv plus bias.

    When ``skip`` is given it is passed through ``self.upsample`` and added
    to the result; ``self.upsample`` only exists when the module was built
    with ``upsample=True``.
    """

    def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]):
        super().__init__()

        if upsample:
            self.upsample = Upsample(blur_kernel)

        self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False)
        self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))

    def forward(self, input, style, skip=None):
        rgb = self.conv(input, style) + self.bias
        if skip is None:
            return rgb
        return rgb + self.upsample(skip)
class Generator(nn.Module):
    """Style-based image generator.

    Input codes are mapped through ``self.style`` (PixelNorm followed by
    ``n_mlp`` EqualLinear layers). Starting from a learned constant, each
    resolution from 8 up to ``size`` applies two styled convolutions (the
    first one upsampling) and accumulates an RGB image through ``ToRGB``
    skip connections.
    """

    def __init__(
        self,
        size,
        style_dim,
        n_mlp,
        channel_multiplier=2,
        blur_kernel=[1, 3, 3, 1],
        lr_mlp=0.01,
    ):
        super().__init__()

        self.size = size
        self.style_dim = style_dim

        mapping = [PixelNorm()]
        mapping.extend(
            EqualLinear(style_dim, style_dim, lr_mul=lr_mlp, activation="fused_lrelu")
            for _ in range(n_mlp)
        )
        self.style = nn.Sequential(*mapping)

        # Feature width per resolution.
        self.channels = {
            4: 512,
            8: 512,
            16: 512,
            32: 512,
            64: 256 * channel_multiplier,
            128: 128 * channel_multiplier,
            256: 64 * channel_multiplier,
            512: 32 * channel_multiplier,
            1024: 16 * channel_multiplier,
        }

        base_channel = self.channels[4]
        self.input = ConstantInput(base_channel)
        self.conv1 = StyledConv(
            base_channel, base_channel, 3, style_dim, blur_kernel=blur_kernel
        )
        self.to_rgb1 = ToRGB(base_channel, style_dim, upsample=False)

        self.log_size = int(math.log(size, 2))
        self.num_layers = (self.log_size - 2) * 2 + 1

        self.convs = nn.ModuleList()
        self.upsamples = nn.ModuleList()
        self.to_rgbs = nn.ModuleList()
        self.noises = nn.Module()

        # One fixed noise buffer per layer, used when randomize_noise is False.
        for layer_idx in range(self.num_layers):
            side = 2 ** ((layer_idx + 5) // 2)
            self.noises.register_buffer(
                f"noise_{layer_idx}", torch.randn(1, 1, side, side)
            )

        in_channel = base_channel
        for i in range(3, self.log_size + 1):
            out_channel = self.channels[2 ** i]

            self.convs.append(
                StyledConv(
                    in_channel,
                    out_channel,
                    3,
                    style_dim,
                    upsample=True,
                    blur_kernel=blur_kernel,
                )
            )
            self.convs.append(
                StyledConv(
                    out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel
                )
            )
            self.to_rgbs.append(ToRGB(out_channel, style_dim))

            in_channel = out_channel

        self.n_latent = self.log_size * 2 - 2

    def make_noise(self):
        """Draw one fresh noise map per layer on the generator's device."""
        device = self.input.input.device

        noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)]
        for i in range(3, self.log_size + 1):
            side = 2 ** i
            noises.extend(
                torch.randn(1, 1, side, side, device=device) for _ in range(2)
            )
        return noises

    def mean_latent(self, n_latent):
        """Average the mapped latents of ``n_latent`` random codes."""
        codes = torch.randn(n_latent, self.style_dim, device=self.input.input.device)
        return self.style(codes).mean(0, keepdim=True)

    def get_latent(self, input):
        """Map ``input`` through the style network."""
        return self.style(input)

    def forward(
        self,
        styles,
        return_latents=False,
        inject_index=None,
        truncation=1,
        truncation_latent=None,
        input_is_latent=False,
        noise=None,
        randomize_noise=True,
    ):
        """Generate an image.

        Args:
            styles: Sequence of one or two code tensors.
            return_latents: When true, also return the per-layer latents.
            inject_index: With two styles, the layer index where the second
                takes over; drawn at random when ``None``.
            truncation: When below 1, each style is pulled towards
                ``truncation_latent`` by this factor.
            truncation_latent: Anchor used when ``truncation < 1``.
            input_is_latent: When true, ``styles`` skip the style network.
            noise: Optional per-layer noise list.
            randomize_noise: When ``noise`` is ``None``, selects fresh noise
                per call (true) or the registered buffers (false).

        Returns:
            ``(image, latent)`` when ``return_latents`` is true, otherwise
            ``(image, None)``.
        """
        if not input_is_latent:
            styles = [self.style(s) for s in styles]

        if noise is None:
            if randomize_noise:
                noise = [None] * self.num_layers
            else:
                noise = [
                    getattr(self.noises, f"noise_{i}") for i in range(self.num_layers)
                ]

        if truncation < 1:
            styles = [
                truncation_latent + truncation * (style - truncation_latent)
                for style in styles
            ]

        if len(styles) < 2:
            inject_index = self.n_latent
            first = styles[0]
            if first.ndim < 3:
                latent = first.unsqueeze(1).repeat(1, inject_index, 1)
            else:
                latent = first
        else:
            if inject_index is None:
                inject_index = random.randint(1, self.n_latent - 1)

            head = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            tail = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)
            latent = torch.cat([head, tail], 1)

        out = self.input(latent)
        out = self.conv1(out, latent[:, 0], noise=noise[0])
        skip = self.to_rgb1(out, latent[:, 1])

        stages = zip(
            self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
        )
        for stage, (up_conv, conv, noise_up, noise_conv, to_rgb) in enumerate(stages):
            # Each stage advances two latent slots; its ToRGB reads the slot
            # that the next stage's first conv also uses.
            i = 1 + 2 * stage
            out = up_conv(out, latent[:, i], noise=noise_up)
            out = conv(out, latent[:, i + 1], noise=noise_conv)
            skip = to_rgb(out, latent[:, i + 2], skip)

        image = skip
        return image, (latent if return_latents else None)
class ConvLayer(nn.Sequential):
    """Optional blur, an EqualConv2d, and an optional FusedLeakyReLU.

    With ``downsample`` the convolution uses stride 2 and no padding and is
    preceded by a ``Blur``; otherwise it uses stride 1 and
    ``kernel_size // 2`` padding. When ``activate`` is true the bias is
    carried by the activation instead of the convolution.
    """

    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size,
        downsample=False,
        blur_kernel=[1, 3, 3, 1],
        bias=True,
        activate=True,
    ):
        if downsample:
            stride, padding = 2, 0
        else:
            stride, padding = 1, kernel_size // 2

        layers = []

        if downsample:
            factor = 2
            p = (len(blur_kernel) - factor) + (kernel_size - 1)
            layers.append(Blur(blur_kernel, pad=((p + 1) // 2, p // 2)))

        layers.append(
            EqualConv2d(
                in_channel,
                out_channel,
                kernel_size,
                padding=padding,
                stride=stride,
                bias=bias and not activate,
            )
        )

        if activate:
            layers.append(FusedLeakyReLU(out_channel, bias=bias))

        super().__init__(*layers)
        # Plain attribute recording the padding given to the convolution.
        self.padding = padding
class ResBlock(nn.Module):
    """Residual block that halves the spatial resolution.

    The main path is a 3x3 ConvLayer followed by a downsampling 3x3
    ConvLayer; the skip path is a downsampling 1x1 ConvLayer without bias or
    activation. The two are summed and divided by ``sqrt(2)``.

    Args:
        in_channel: Input feature count.
        out_channel: Output feature count.
        blur_kernel: Blur filter taps used by the two downsampling layers.
    """

    def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
        super().__init__()

        self.conv1 = ConvLayer(in_channel, in_channel, 3)
        # blur_kernel was previously accepted but never forwarded, so a
        # caller-supplied kernel had no effect on the downsampling layers.
        self.conv2 = ConvLayer(
            in_channel, out_channel, 3, downsample=True, blur_kernel=blur_kernel
        )
        self.skip = ConvLayer(
            in_channel,
            out_channel,
            1,
            downsample=True,
            blur_kernel=blur_kernel,
            activate=False,
            bias=False,
        )

    def forward(self, input):
        out = self.conv1(input)
        out = self.conv2(out)

        skip = self.skip(input)
        out = (out + skip) / math.sqrt(2)

        return out
class Discriminator(nn.Module):
    """Residual convolutional discriminator producing one score per image.

    A 1x1 ConvLayer is followed by one ResBlock per halving of resolution
    down to 4x4. A standard-deviation feature computed across groups of the
    batch is appended as an extra channel before the final convolution and
    the two-layer linear head.
    """

    def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]):
        super().__init__()

        # Feature width per resolution.
        channels = {
            4: 512,
            8: 512,
            16: 512,
            32: 512,
            64: 256 * channel_multiplier,
            128: 128 * channel_multiplier,
            256: 64 * channel_multiplier,
            512: 32 * channel_multiplier,
            1024: 16 * channel_multiplier,
        }

        in_channel = channels[size]
        blocks = [ConvLayer(3, in_channel, 1)]

        log_size = int(math.log(size, 2))
        for level in range(log_size, 2, -1):
            out_channel = channels[2 ** (level - 1)]
            blocks.append(ResBlock(in_channel, out_channel, blur_kernel))
            in_channel = out_channel

        self.convs = nn.Sequential(*blocks)

        self.stddev_group = 4
        self.stddev_feat = 1

        # +1 input channel for the appended standard-deviation map.
        self.final_conv = ConvLayer(in_channel + 1, channels[4], 3)
        self.final_linear = nn.Sequential(
            EqualLinear(channels[4] * 4 * 4, channels[4], activation="fused_lrelu"),
            EqualLinear(channels[4], 1),
        )

    def forward(self, input):
        features = self.convs(input)

        batch, channel, height, width = features.shape
        group = min(batch, self.stddev_group)

        # Standard deviation over the group axis, averaged over channels and
        # space, then tiled back to a (batch, stddev_feat, H, W) map.
        grouped = features.view(
            group, -1, self.stddev_feat, channel // self.stddev_feat, height, width
        )
        stddev = torch.sqrt(grouped.var(0, unbiased=False) + 1e-8)
        stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
        stddev = stddev.repeat(group, 1, height, width)

        out = torch.cat([features, stddev], 1)
        out = self.final_conv(out)
        out = out.view(batch, -1)
        return self.final_linear(out)

View File

@ -1,243 +0,0 @@
from functools import partial
from math import log2
from random import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import models.archs.stylegan.stylegan2 as sg2
import models.steps.losses as L
def leaky_relu(p=0.2):
    """Return an ``nn.LeakyReLU`` module with negative slope ``p``."""
    return nn.LeakyReLU(negative_slope=p)
def double_conv(chan_in, chan_out):
    """Return two 3x3, padding-1 convolutions, each followed by a leaky ReLU."""
    layers = [
        nn.Conv2d(chan_in, chan_out, 3, padding=1),
        leaky_relu(),
        nn.Conv2d(chan_out, chan_out, 3, padding=1),
        leaky_relu(),
    ]
    return nn.Sequential(*layers)
class Flatten(nn.Module):
    """Flatten every dimension from ``index`` onwards into one."""

    def __init__(self, index):
        super().__init__()
        self.index = index

    def forward(self, x):
        return torch.flatten(x, self.index)
class DownBlock(nn.Module):
    """Encoder block: double conv, optional stride-2 conv, plus a 1x1 residual.

    ``forward`` returns ``(x, unet_res)`` where ``unet_res`` is the double-conv
    output taken before the optional downsampling.
    """

    def __init__(self, input_channels, filters, downsample=True):
        super().__init__()
        residual_stride = 2 if downsample else 1
        self.conv_res = nn.Conv2d(input_channels, filters, 1, stride=residual_stride)
        self.net = double_conv(input_channels, filters)
        if downsample:
            self.down = nn.Conv2d(filters, filters, 3, padding=1, stride=2)
        else:
            self.down = None

    def forward(self, x):
        residual = self.conv_res(x)
        unet_res = self.net(x)
        out = unet_res if self.down is None else self.down(unet_res)
        return out + residual, unet_res
class UpBlock(nn.Module):
    """Decoder block: bilinear 2x upsample, concatenate ``res``, double conv.

    A stride-2 1x1 transposed convolution of the input is added to the result
    as a residual.
    """

    def __init__(self, input_channels, filters):
        super().__init__()
        self.conv_res = nn.ConvTranspose2d(input_channels // 2, filters, 1, stride=2)
        self.net = double_conv(input_channels, filters)
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        self.input_channels = input_channels
        self.filters = filters

    def forward(self, x, res):
        h, w = x.shape[-2], x.shape[-1]
        # output_size pins the transposed conv to exactly twice the input size.
        residual = self.conv_res(x, output_size=(h * 2, w * 2))
        merged = torch.cat((self.up(x), res), dim=1)
        return self.net(merged) + residual
class StyleGan2UnetDiscriminator(nn.Module):
    """Encoder/decoder discriminator.

    ``forward`` returns ``(dec_out, enc_out)``: the output of the decoder's
    final 1x1 convolution and the logit computed from the encoder
    bottleneck.
    """

    def __init__(self, image_size, network_capacity=16, fmap_max=512, input_filters=3):
        super().__init__()
        num_layers = int(log2(image_size) - 3)

        # Channel widths: input, then doubling from network_capacity, each
        # capped at fmap_max; the last width repeats the one before it.
        filters = [input_filters] + [
            network_capacity * (2 ** i) for i in range(num_layers + 1)
        ]
        filters = [min(fmap_max, width) for width in filters]
        filters[-1] = filters[-2]

        chan_in_out = [[c_in, c_out] for c_in, c_out in zip(filters[:-1], filters[1:])]
        last_index = len(chan_in_out) - 1

        down_blocks = []
        attn_blocks = []
        for ind, (in_chan, out_chan) in enumerate(chan_in_out):
            # Every encoder block downsamples except the last one.
            down_blocks.append(
                DownBlock(in_chan, out_chan, downsample=(ind != last_index))
            )
            attn_blocks.append(sg2.attn_and_ff(out_chan))

        self.down_blocks = nn.ModuleList(down_blocks)
        self.attn_blocks = nn.ModuleList(attn_blocks)

        last_chan = filters[-1]

        self.to_logit = nn.Sequential(
            leaky_relu(),
            nn.AvgPool2d(image_size // (2 ** num_layers)),
            Flatten(1),
            nn.Linear(last_chan, 1)
        )

        self.conv = double_conv(last_chan, last_chan)

        # The decoder mirrors the encoder, excluding the last encoder block.
        self.up_blocks = nn.ModuleList(
            [UpBlock(c_out * 2, c_in) for c_in, c_out in chan_in_out[:-1][::-1]]
        )
        self.conv_out = nn.Conv2d(input_filters, 1, 1)

    def forward(self, x):
        residuals = []

        for down_block, attn_block in zip(self.down_blocks, self.attn_blocks):
            x, unet_res = down_block(x)
            residuals.append(unet_res)

            if attn_block is not None:
                x = attn_block(x)

        x = self.conv(x) + x
        enc_out = self.to_logit(x)

        # The last residual is not consumed by the decoder.
        for up_block, res in zip(self.up_blocks, residuals[:-1][::-1]):
            x = up_block(x, res)

        dec_out = self.conv_out(x)
        return dec_out, enc_out
def warmup(start, end, max_steps, current_step):
    """Interpolate linearly from ``start`` to ``end`` over ``max_steps`` steps.

    Returns ``end`` once ``current_step`` exceeds ``max_steps``.
    """
    if current_step > max_steps:
        return end
    progress = current_step / max_steps
    return (end - start) * progress + start
def mask_src_tgt(source, target, mask):
    """Blend ``source`` (weighted by ``mask``) with ``target`` (weighted by ``1 - mask``)."""
    from_source = source * mask
    from_target = (1 - mask) * target
    return from_source + from_target
def cutmix(source, target, coors, alpha = 1.):
    """Return a copy of ``source`` with one rectangle taken from ``target``.

    ``coors`` is ``(((y0, y1), (x0, x1)), lam)``; only the box is used.
    ``alpha`` is accepted but not used. Neither input is modified.
    """
    mixed = torch.clone(source)
    patch_source = torch.clone(target)
    box, _ = coors
    (y0, y1), (x0, x1) = box
    mixed[:, :, y0:y1, x0:x1] = patch_source[:, :, y0:y1, x0:x1]
    return mixed
def cutmix_coordinates(height, width, alpha = 1.):
    """Sample a random box and its mixing coefficient.

    ``lam`` is drawn from ``Beta(alpha, alpha)``. The box is centred at a
    uniformly drawn point, spans ``sqrt(1 - lam)`` of each side length, and
    is clipped to the image.

    Returns:
        ``(((y0, y1), (x0, x1)), lam)`` with integer box coordinates.
    """
    lam = np.random.beta(alpha, alpha)

    # Draw order (beta, x centre, y centre) is part of the observable behaviour.
    cx = np.random.uniform(0, width)
    cy = np.random.uniform(0, height)

    side_fraction = np.sqrt(1 - lam)
    half_w = width * side_fraction / 2
    half_h = height * side_fraction / 2

    x_range = (
        int(np.round(max(cx - half_w, 0))),
        int(np.round(min(cx + half_w, width))),
    )
    y_range = (
        int(np.round(max(cy - half_h, 0))),
        int(np.round(min(cy + half_h, height))),
    )
    return (y_range, x_range), lam
class StyleGan2UnetDivergenceLoss(L.ConfigurableLoss):
    """Hinge-style divergence loss for a discriminator with two outputs.

    The discriminator is expected to return ``(dec_out, enc_out)``. In
    generator mode only the fake outputs are scored. In discriminator mode
    the decoder term, the cutmix probability and the consistency term are
    ramped up over the first 30000 steps, and a gradient penalty is added
    every ``gradient_penalty_frequency`` steps.
    """

    def __init__(self, opt, env):
        super().__init__(opt, env)
        self.real = opt['real']
        self.fake = opt['fake']
        self.discriminator = opt['discriminator']
        self.for_gen = opt['gen_loss']
        self.gp_frequency = opt['gradient_penalty_frequency']
        if 'noise' in opt.keys():
            self.noise = opt['noise']
        else:
            self.noise = 0
        self.image_size = opt['image_size']
        self.cr_weight = .2

    def forward(self, net, state):
        real_input = state[self.real]
        fake_input = state[self.fake]
        if self.noise != 0:
            # Uniform noise scaled by self.noise; fake is perturbed first.
            fake_input = fake_input + torch.rand_like(fake_input) * self.noise
            real_input = real_input + torch.rand_like(real_input) * self.noise

        D = self.env['discriminators'][self.discriminator]
        fake_dec, fake_enc = D(fake_input)
        # Read right after the fake pass, before D is called again.
        fake_aug_images = D.module.aug_images

        if self.for_gen:
            return fake_enc.mean() + F.relu(1 + fake_dec).mean()

        dec_loss_coef = warmup(0, 1., 30000, self.env['step'])
        cutmix_prob = warmup(0, 0.25, 30000, self.env['step'])
        apply_cutmix = random() < cutmix_prob

        real_input.requires_grad_()  # Needed to compute gradients on the input.
        real_dec, real_enc = D(real_input)
        real_aug_images = D.module.aug_images

        enc_divergence = (F.relu(1 + real_enc) + F.relu(1 - fake_enc)).mean()
        dec_divergence = (F.relu(1 + real_dec) + F.relu(1 - fake_dec)).mean()
        disc_loss = enc_divergence + dec_divergence * dec_loss_coef

        if apply_cutmix:
            mask = cutmix(
                torch.ones_like(real_dec),
                torch.zeros_like(real_dec),
                cutmix_coordinates(self.image_size, self.image_size)
            )
            if random() > 0.5:
                mask = 1 - mask

            cutmix_images = mask_src_tgt(real_aug_images, fake_aug_images, mask)
            # D.module.D bypasses the augmentation wrapper.
            cutmix_dec_out, cutmix_enc_out = D.module.D(cutmix_images)

            cutmix_enc_divergence = F.relu(1 - cutmix_enc_out).mean()
            cutmix_dec_divergence = F.relu(1 + (mask * 2 - 1) * cutmix_dec_out).mean()
            disc_loss = disc_loss + cutmix_enc_divergence + cutmix_dec_divergence

            # Consistency term: the decoder output on the mixed image should
            # match the same mix of the decoder outputs.
            cr_cutmix_dec_out = mask_src_tgt(real_dec, fake_dec, mask)
            cr_loss = F.mse_loss(cutmix_dec_out, cr_cutmix_dec_out) * self.cr_weight
            self.last_cr_loss = cr_loss.clone().detach().item()
            disc_loss = disc_loss + cr_loss * dec_loss_coef

        # Apply gradient penalty. TODO: migrate this elsewhere.
        if self.env['step'] % self.gp_frequency == 0:
            from models.archs.stylegan.stylegan2 import gradient_penalty
            # Penalise one of the two outputs, chosen at random.
            if random() < .5:
                gp = gradient_penalty(real_input, real_enc)
            else:
                gp = gradient_penalty(real_input, real_dec) * dec_loss_coef
            self.metrics.append(("gradient_penalty", gp.clone().detach()))
            disc_loss = disc_loss + gp

        real_input.requires_grad_(requires_grad=False)
        return disc_loss

View File

@ -6,8 +6,7 @@ import munch
import torch import torch
import torchvision import torchvision
from munch import munchify from munch import munchify
import models.archs.stylegan.stylegan2 as stylegan2 import models.archs.stylegan.stylegan2_lucidrains as stylegan2
import models.archs.stylegan.stylegan2_unet_disc as stylegan2_unet
import models.archs.fixup_resnet.DiscriminatorResnet_arch as DiscriminatorResnet_arch import models.archs.fixup_resnet.DiscriminatorResnet_arch as DiscriminatorResnet_arch
import models.archs.RRDBNet_arch as RRDBNet_arch import models.archs.RRDBNet_arch as RRDBNet_arch
@ -228,9 +227,6 @@ def define_D_net(opt_net, img_sz=None, wrap=False):
attn = opt_net['attn_layers'] if 'attn_layers' in opt_net.keys() else [] attn = opt_net['attn_layers'] if 'attn_layers' in opt_net.keys() else []
disc = stylegan2.StyleGan2Discriminator(image_size=opt_net['image_size'], input_filters=opt_net['in_nc'], attn_layers=attn) disc = stylegan2.StyleGan2Discriminator(image_size=opt_net['image_size'], input_filters=opt_net['in_nc'], attn_layers=attn)
netD = stylegan2.StyleGan2Augmentor(disc, opt_net['image_size'], types=opt_net['augmentation_types'], prob=opt_net['augmentation_probability']) netD = stylegan2.StyleGan2Augmentor(disc, opt_net['image_size'], types=opt_net['augmentation_types'], prob=opt_net['augmentation_probability'])
elif which_model == "stylegan2_unet":
disc = stylegan2_unet.StyleGan2UnetDiscriminator(image_size=opt_net['image_size'], input_filters=opt_net['in_nc'])
netD = stylegan2.StyleGan2Augmentor(disc, opt_net['image_size'], types=opt_net['augmentation_types'], prob=opt_net['augmentation_probability'])
elif which_model == "rrdb_disc": elif which_model == "rrdb_disc":
netD = RRDBNet_arch.RRDBDiscriminator(opt_net['in_nc'], opt_net['nf'], opt_net['nb'], blocks_per_checkpoint=3) netD = RRDBNet_arch.RRDBDiscriminator(opt_net['in_nc'], opt_net['nf'], opt_net['nb'], blocks_per_checkpoint=3)
else: else:

View File

@ -312,7 +312,7 @@ class DiscriminatorGanLoss(ConfigurableLoss):
if self.gradient_penalty: if self.gradient_penalty:
# Apply gradient penalty. TODO: migrate this elsewhere. # Apply gradient penalty. TODO: migrate this elsewhere.
from models.archs.stylegan.stylegan2 import gradient_penalty from models.archs.stylegan.stylegan2_lucidrains import gradient_penalty
assert len(real) == 1 # Grad penalty doesn't currently support multi-input discriminators. assert len(real) == 1 # Grad penalty doesn't currently support multi-input discriminators.
gp = gradient_penalty(real[0], d_real) gp = gradient_penalty(real[0], d_real)
self.metrics.append(("gradient_penalty", gp.clone().detach())) self.metrics.append(("gradient_penalty", gp.clone().detach()))

View File

@ -1,6 +1,6 @@
from torch.cuda.amp import autocast from torch.cuda.amp import autocast
from models.archs.stylegan.stylegan2 import gradient_penalty from models.archs.stylegan.stylegan2_lucidrains import gradient_penalty
from models.steps.losses import ConfigurableLoss, GANLoss, extract_params_from_state, get_basic_criterion_for_name from models.steps.losses import ConfigurableLoss, GANLoss, extract_params_from_state, get_basic_criterion_for_name
from models.flownet2.networks.resample2d_package.resample2d import Resample2d from models.flownet2.networks.resample2d_package.resample2d import Resample2d
from models.steps.injectors import Injector from models.steps.injectors import Injector

View File

@ -0,0 +1,292 @@
# Converts from Tensorflow Stylegan2 weights to weights used by this model.
# Original source: https://raw.githubusercontent.com/rosinality/stylegan2-pytorch/master/convert_weight.py
# Adapted to lucidrains' Stylegan implementation.
#
# Also doesn't require you to install Tensorflow 1.15 or clone the nVidia repo.
import argparse
import os
import sys
import pickle
import math
import torch
import numpy as np
from torchvision import utils
from models.archs.stylegan.stylegan2_rosinality import Generator, Discriminator
# Converts from the TF state_dict input provided into the vars originally expected from the rosinality converter.
def get_vars(vars, source_name):
    """Split ``source_name`` into its network and return that network's variables.

    The first ``/``-separated segment of ``source_name`` selects an entry of
    ``vars``; that entry's ``'variables'`` list of ``(name, value, ...)``
    tuples is turned into a ``{name: value}`` dict.

    Returns:
        A ``(variables, relative_name)`` tuple, where ``relative_name`` is
        ``source_name`` with every ``"<network>/"`` occurrence removed.
    """
    network, *_ = source_name.split('/')
    lookup = {entry[0]: entry[1] for entry in vars[network]['variables']}
    relative_name = source_name.replace(network + "/", "")
    return lookup, relative_name
def get_vars_direct(vars, source_name):
    """Return the single variable value addressed by ``source_name``.

    Resolution is delegated to ``get_vars``; a ``KeyError`` propagates when
    either the network or the variable is missing.
    """
    table, key = get_vars(vars, source_name)
    return table[key]
def convert_modconv(vars, source_name, target_name, flip=False):
    """Convert one modulated-convolution layer into PyTorch state-dict entries.

    Reads the ``weight``, ``mod_weight``, ``mod_bias``, ``noise_strength`` and
    ``bias`` variables found under ``source_name`` and returns a dict of
    tensors keyed by ``target_name + "." + <parameter name>``.

    When ``flip`` is true the converted convolution weight is additionally
    flipped along its last two axes.
    """
    tf_vars, scope = get_vars(vars, source_name)
    weight, mod_weight, mod_bias, noise, bias = (
        tf_vars[f"{scope}/{leaf}"]
        for leaf in ("weight", "mod_weight", "mod_bias", "noise_strength", "bias")
    )

    numpy_params = {
        # Axes are permuted to (3, 2, 0, 1) and a leading unit axis is added.
        "conv.weight": np.expand_dims(weight.transpose((3, 2, 0, 1)), 0),
        "conv.modulation.weight": mod_weight.transpose((1, 0)),
        # The stored modulation bias is offset by one for the target layer.
        "conv.modulation.bias": mod_bias + 1,
        "noise.weight": np.array([noise]),
        "activate.bias": bias,
    }
    converted = {
        target_name + "." + suffix: torch.from_numpy(array)
        for suffix, array in numpy_params.items()
    }

    if flip:
        conv_key = target_name + ".conv.weight"
        converted[conv_key] = torch.flip(converted[conv_key], [3, 4])
    return converted
def convert_conv(vars, source_name, target_name, bias=True, start=0):
    """Convert a plain convolution layer into PyTorch state-dict entries.

    The weight is stored under ``target_name.<start>.weight``; when ``bias``
    is truthy the bias is stored under ``target_name.<start + 1>.bias``.
    """
    tf_vars, scope = get_vars(vars, source_name)
    kernel = tf_vars[scope + "/weight"].transpose((3, 2, 0, 1))
    bias_values = tf_vars[scope + "/bias"] if bias else None

    converted = {target_name + f".{start}.weight": torch.from_numpy(kernel)}
    if bias:
        converted[target_name + f".{start + 1}.bias"] = torch.from_numpy(bias_values)
    return converted
def convert_torgb(vars, source_name, target_name):
    """Convert a ToRGB layer into PyTorch state-dict entries.

    Reads ``weight``, ``mod_weight``, ``mod_bias`` and ``bias`` from under
    ``source_name``; the bias is reshaped to ``(1, 3, 1, 1)``. Returned keys
    are ``target_name + "." + <parameter name>``.
    """
    tf_vars, scope = get_vars(vars, source_name)
    weight, mod_weight, mod_bias, bias = (
        tf_vars[f"{scope}/{leaf}"]
        for leaf in ("weight", "mod_weight", "mod_bias", "bias")
    )

    numpy_params = {
        "conv.weight": np.expand_dims(weight.transpose((3, 2, 0, 1)), 0),
        "conv.modulation.weight": mod_weight.transpose((1, 0)),
        # The stored modulation bias is offset by one for the target layer.
        "conv.modulation.bias": mod_bias + 1,
        "bias": bias.reshape((1, 3, 1, 1)),
    }
    return {
        target_name + "." + suffix: torch.from_numpy(array)
        for suffix, array in numpy_params.items()
    }
def convert_dense(vars, source_name, target_name):
    """Convert a dense layer into PyTorch state-dict entries.

    The weight matrix is transposed; the bias is copied as-is. Keys are
    ``target_name + ".weight"`` and ``target_name + ".bias"``.
    """
    tf_vars, scope = get_vars(vars, source_name)
    weight = tf_vars[scope + "/weight"]
    bias = tf_vars[scope + "/bias"]

    numpy_params = {"weight": weight.transpose((1, 0)), "bias": bias}
    return {
        target_name + "." + suffix: torch.from_numpy(array)
        for suffix, array in numpy_params.items()
    }
def update(state_dict, new):
    """Copy every entry of ``new`` into ``state_dict`` in place; returns None."""
    state_dict.update(new)
def discriminator_fill_statedict(statedict, vars, size):
    """Fill ``statedict`` with converted discriminator weights and return it.

    Args:
        statedict: state dict of the target discriminator; updated in place.
        vars: mapping of captured network states, as consumed by ``get_vars``.
        size: image resolution of the network (e.g. 1024).

    Returns:
        The same ``statedict`` object, with the converted entries written.
    """
    # get_vars() interprets the first path segment as the network name and
    # looks it up in `vars`. The captured discriminator state is stored under
    # the key 'D', so names must be prefixed with "D/" to resolve. Mappings
    # without a 'D' entry keep the previous un-prefixed behaviour.
    net = "D/" if "D" in vars else ""
    log_size = int(math.log(size, 2))

    update(statedict, convert_conv(vars, f"{net}{size}x{size}/FromRGB", "convs.0"))

    conv_i = 1
    # Walk the residual blocks from the highest resolution down to 8x8.
    for i in range(log_size - 2, 0, -1):
        reso = 4 * 2 ** i
        block = f"{net}{reso}x{reso}"
        target = f"convs.{conv_i}"
        update(statedict, convert_conv(vars, f"{block}/Conv0", f"{target}.conv1"))
        update(
            statedict,
            convert_conv(vars, f"{block}/Conv1_down", f"{target}.conv2", start=1),
        )
        update(
            statedict,
            convert_conv(vars, f"{block}/Skip", f"{target}.skip", start=1, bias=False),
        )
        conv_i += 1

    update(statedict, convert_conv(vars, f"{net}4x4/Conv", "final_conv"))
    update(statedict, convert_dense(vars, f"{net}4x4/Dense0", "final_linear.0"))
    update(statedict, convert_dense(vars, f"{net}Output", "final_linear.1"))
    return statedict
def fill_statedict(state_dict, vars, size):
    """Fill ``state_dict`` with converted generator weights and return it.

    Converts, in order: the eight mapping dense layers, the constant input,
    every ToRGB layer, every modulated convolution and the noise buffers.
    ``vars`` is the mapping of captured network states used by ``get_vars``;
    ``size`` is the output resolution of the generator.
    """
    log_size = int(math.log(size, 2))
    stages = log_size - 2

    # Mapping network: Dense0..Dense7 land on style.1..style.8.
    for layer in range(8):
        update(
            state_dict,
            convert_dense(vars, f"G_mapping/Dense{layer}", f"style.{layer + 1}"),
        )

    const_input = torch.from_numpy(
        get_vars_direct(vars, "G_synthesis/4x4/Const/const")
    )
    update(state_dict, {"input.input": const_input})

    # ToRGB layers: the 4x4 one first, then one per higher resolution.
    update(state_dict, convert_torgb(vars, "G_synthesis/4x4/ToRGB", "to_rgb1"))
    for stage in range(stages):
        reso = 4 * 2 ** (stage + 1)
        update(
            state_dict,
            convert_torgb(
                vars, f"G_synthesis/{reso}x{reso}/ToRGB", f"to_rgbs.{stage}"
            ),
        )

    # Modulated convolutions: two per resolution above 4x4; the upsampling
    # one is converted with flip=True.
    update(state_dict, convert_modconv(vars, "G_synthesis/4x4/Conv", "conv1"))
    for stage in range(stages):
        reso = 4 * 2 ** (stage + 1)
        up_index = 2 * stage
        update(
            state_dict,
            convert_modconv(
                vars,
                f"G_synthesis/{reso}x{reso}/Conv0_up",
                f"convs.{up_index}",
                flip=True,
            ),
        )
        update(
            state_dict,
            convert_modconv(
                vars, f"G_synthesis/{reso}x{reso}/Conv1", f"convs.{up_index + 1}"
            ),
        )

    for noise_index in range(0, stages * 2 + 1):
        noise = torch.from_numpy(
            get_vars_direct(vars, f"G_synthesis/noise{noise_index}")
        )
        update(state_dict, {f"noises.noise_{noise_index}": noise})

    return state_dict
if __name__ == "__main__":
    # Script entry point: convert a pickled TensorFlow StyleGAN2 checkpoint to
    # a PyTorch ``.pt`` file and write a sample image grid next to it.
    device = "cuda"

    parser = argparse.ArgumentParser(
        description="Tensorflow to pytorch model checkpoint converter"
    )
    parser.add_argument(
        "--gen", action="store_true", help="convert the generator weights"
    )
    parser.add_argument(
        "--disc", action="store_true", help="convert the discriminator weights"
    )
    parser.add_argument(
        "--channel_multiplier",
        type=int,
        default=2,
        help="channel multiplier factor. config-f = 2, else = 1",
    )
    parser.add_argument("path", metavar="PATH", help="path to the tensorflow weights")

    args = parser.parse_args()

    # Make the stand-in ``dnnlib`` package importable. os.path.join keeps the
    # path valid on non-Windows platforms (it was a hard-coded '\\' before).
    sys.path.append(os.path.join('scripts', 'stylegan2'))
    import dnnlib
    from dnnlib.tflib.network import generator, discriminator, gen_ema

    # The return value is intentionally discarded: unpickling populates the
    # three dicts imported above as a side effect.
    # NOTE: pickle.load can execute arbitrary code - only use trusted files.
    with open(args.path, "rb") as f:
        pickle.load(f)

    # Weight names are ordered by size. The last name will be something like '1024x1024/<blah>'. We just need to grab that first number.
    size = int(generator['G_synthesis']['variables'][-1][0].split('x')[0])

    g = Generator(size, 512, 8, channel_multiplier=args.channel_multiplier)
    state_dict = g.state_dict()
    state_dict = fill_statedict(state_dict, gen_ema, size)

    # strict=True makes a missing or unexpected key fail loudly here.
    g.load_state_dict(state_dict, strict=True)

    latent_avg = torch.from_numpy(get_vars_direct(gen_ema, "G/dlatent_avg"))

    ckpt = {"g_ema": state_dict, "latent_avg": latent_avg}

    if args.gen:
        g_train = Generator(size, 512, 8, channel_multiplier=args.channel_multiplier)
        g_train_state = g_train.state_dict()
        g_train_state = fill_statedict(g_train_state, generator, size)
        ckpt["g"] = g_train_state

    if args.disc:
        disc = Discriminator(size, channel_multiplier=args.channel_multiplier)
        d_state = disc.state_dict()
        # `discriminator` is a plain dict of captured states; it has no
        # `.vars` attribute, so the mapping itself is passed on.
        d_state = discriminator_fill_statedict(d_state, discriminator, size)
        ckpt["d"] = d_state

    name = os.path.splitext(os.path.basename(args.path))[0]
    torch.save(ckpt, name + ".pt")

    # Fewer samples at higher resolutions; 25 for any size not listed.
    batch_size = {256: 16, 512: 9, 1024: 4}
    n_sample = batch_size.get(size, 25)

    g = g.to(device)

    # Fixed seed so the sample grid is reproducible between runs.
    z = np.random.RandomState(0).randn(n_sample, 512).astype("float32")

    with torch.no_grad():
        img_pt, _ = g(
            [torch.from_numpy(z).to(device)],
            truncation=0.5,
            truncation_latent=latent_avg.to(device),
            randomize_noise=False,
        )

    # NOTE(review): newer torchvision releases appear to rename `range` to
    # `value_range` - confirm against the installed version.
    utils.save_image(
        img_pt, name + ".png", nrow=n_sample, normalize=True, range=(-1, 1)
    )

View File

@ -0,0 +1,17 @@
# Pretends to be the stylegan2 Network class for intercepting pickle load requests.
# Horrible hack. Please don't judge me.
# Globals for storing these networks because I have no idea how pickle is doing this internally.
generator, discriminator, gen_ema = {}, {}, {}


class Network:
    """Stand-in unpickling target that records network states in module dicts.

    Nothing is stored on the instance itself; every state dict passed to
    ``__setstate__`` is filed by its ``'name'`` entry into ``generator``,
    ``gen_ema`` or ``discriminator``.
    """

    def __setstate__(self, state: dict) -> None:
        net_name = state['name']

        if net_name == 'D':
            discriminator[net_name] = state
            return
        if net_name not in ('G_synthesis', 'G_mapping', 'G', 'G_main'):
            # Any other network is ignored.
            return

        # The first state seen under a name goes to `generator`; a repeat of
        # that name, and every state named 'G', goes to `gen_ema`.
        first_sighting = net_name != 'G' and net_name not in generator
        destination = generator if first_sighting else gen_ema
        destination[net_name] = state