Rosinality stylegan2 port

parent 12cf052889
commit e838c6e75b

@@ -9,7 +9,7 @@ from torchvision import transforms
 import torch.nn as nn
 from pathlib import Path

-import models.archs.stylegan.stylegan2 as sg2
+import models.archs.stylegan.stylegan2_lucidrains as sg2


 def convert_transparent_to_rgb(image):

@@ -1,5 +1,4 @@
-import models.archs.stylegan.stylegan2 as stylegan2
-import models.archs.stylegan.stylegan2_unet_disc as stylegan2_unet
+import models.archs.stylegan.stylegan2_lucidrains as stylegan2


 def create_stylegan2_loss(opt_loss, env):
@@ -8,7 +7,5 @@ def create_stylegan2_loss(opt_loss, env):
         return stylegan2.StyleGan2DivergenceLoss(opt_loss, env)
     elif type == 'stylegan2_pathlen':
         return stylegan2.StyleGan2PathLengthLoss(opt_loss, env)
-    elif type == 'stylegan2_unet_divergence':
-        return stylegan2_unet.StyleGan2UnetDivergenceLoss(opt_loss, env)
     else:
         raise NotImplementedError

codes/models/archs/stylegan/op/__init__.py  (new file, 2 lines)
@@ -0,0 +1,2 @@
from .fused_act import FusedLeakyReLU, fused_leaky_relu
from .upfirdn2d import upfirdn2d

codes/models/archs/stylegan/op/fused_act.py  (new file, 110 lines)
@@ -0,0 +1,110 @@
import os

import torch
from torch import nn
from torch.nn import functional as F
from torch.autograd import Function
from torch.utils.cpp_extension import load
import fused_bias_act_cuda


class FusedLeakyReLUFunctionBackward(Function):
    @staticmethod
    def forward(ctx, grad_output, out, bias, negative_slope, scale):
        ctx.save_for_backward(out)
        ctx.negative_slope = negative_slope
        ctx.scale = scale

        empty = grad_output.new_empty(0)

        grad_input = fused_bias_act_cuda.fused_bias_act(
            grad_output, empty, out, 3, 1, negative_slope, scale
        )

        dim = [0]

        if grad_input.ndim > 2:
            dim += list(range(2, grad_input.ndim))

        if bias:
            grad_bias = grad_input.sum(dim).detach()
        else:
            grad_bias = empty

        return grad_input, grad_bias

    @staticmethod
    def backward(ctx, gradgrad_input, gradgrad_bias):
        out, = ctx.saved_tensors
        gradgrad_out = fused_bias_act_cuda.fused_bias_act(
            gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale
        )

        return gradgrad_out, None, None, None, None


class FusedLeakyReLUFunction(Function):
    @staticmethod
    def forward(ctx, input, bias, negative_slope, scale):
        empty = input.new_empty(0)

        ctx.bias = bias is not None

        if bias is None:
            bias = empty

        out = fused_bias_act_cuda.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale)
        ctx.save_for_backward(out)
        ctx.negative_slope = negative_slope
        ctx.scale = scale

        return out

    @staticmethod
    def backward(ctx, grad_output):
        out, = ctx.saved_tensors

        grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(
            grad_output, out, ctx.bias, ctx.negative_slope, ctx.scale
        )

        if not ctx.bias:
            grad_bias = None

        return grad_input, grad_bias, None, None


class FusedLeakyReLU(nn.Module):
    def __init__(self, channel, bias=True, negative_slope=0.2, scale=2 ** 0.5):
        super().__init__()

        if bias:
            self.bias = nn.Parameter(torch.zeros(channel))
        else:
            self.bias = None

        self.negative_slope = negative_slope
        self.scale = scale

    def forward(self, input):
        return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)


def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2 ** 0.5):
    if input.device.type == "cpu":
        if bias is not None:
            rest_dim = [1] * (input.ndim - bias.ndim - 1)
            return (
                F.leaky_relu(
                    input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=0.2
                )
                * scale
            )
        else:
            return F.leaky_relu(input, negative_slope=0.2) * scale
    else:
        return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)
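
An orientation note, not part of the diff: the integer arguments in the fused_bias_act calls above follow the original StyleGAN2 op convention as used here — the forward pass passes (3, 0) and the backward pass (3, 1), i.e. the first integer selects the activation and the second selects value versus gradient evaluation. A minimal usage sketch, assuming the stylegan2_ops_cuda extensions from setup.py below have already been built and a CUDA device is available:

import torch
from models.archs.stylegan.op import FusedLeakyReLU, fused_leaky_relu

x = torch.randn(4, 64, 32, 32, device="cuda")

# Module form with a learnable per-channel bias: bias add + leaky ReLU
# (slope 0.2) + sqrt(2) rescale, performed by one fused CUDA kernel.
act = FusedLeakyReLU(64).cuda()
y = act(x)

# Functional form; on CPU tensors it falls back to plain F.leaky_relu.
y2 = fused_leaky_relu(x.cpu())
assert y.shape == x.shape and y2.shape == x.shape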

codes/models/archs/stylegan/op/fused_bias_act.cpp  (new file, 21 lines)
@@ -0,0 +1,21 @@
#include <torch/extension.h>


torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
                                int act, int grad, float alpha, float scale);

#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
                             int act, int grad, float alpha, float scale) {
    CHECK_CUDA(input);
    CHECK_CUDA(bias);

    return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)");
}

codes/models/archs/stylegan/op/setup.py  (new file, 32 lines)
@@ -0,0 +1,32 @@
#!/usr/bin/env python3
import os
import torch

from setuptools import setup, find_packages
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

cxx_args = ['-std=c++11']

nvcc_args = [
    '-gencode', 'arch=compute_50,code=sm_50',
    '-gencode', 'arch=compute_52,code=sm_52',
    '-gencode', 'arch=compute_60,code=sm_60',
    '-gencode', 'arch=compute_61,code=sm_61',
    '-gencode', 'arch=compute_70,code=sm_70',
]

setup(
    name='stylegan2_ops_cuda',
    ext_modules=[
        CUDAExtension('fused_bias_act_cuda', [
            'fused_bias_act.cpp',
            'fused_bias_act_kernel.cu'
        ], extra_compile_args={'cxx': cxx_args, 'nvcc': nvcc_args}),
        CUDAExtension('upfirdn2d_cuda', [
            'upfirdn2d.cpp',
            'upfirdn2d_kernel.cu'
        ], extra_compile_args={'cxx': cxx_args, 'nvcc': nvcc_args})
    ],
    cmdclass={
        'build_ext': BuildExtension
    })
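
Build note, not part of the diff: fused_act.py and upfirdn2d.py import the compiled modules fused_bias_act_cuda and upfirdn2d_cuda at module load time rather than JIT-compiling them, so this script has to be run once beforehand — e.g. `python setup.py install` from the op directory, with a CUDA toolkit matching the installed torch. The CUDA kernels it references (fused_bias_act_kernel.cu and upfirdn2d_kernel.cu) are not shown in this view of the commit.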

codes/models/archs/stylegan/op/upfirdn2d.cpp  (new file, 23 lines)
@@ -0,0 +1,23 @@
#include <torch/extension.h>


torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,
                           int up_x, int up_y, int down_x, int down_y,
                           int pad_x0, int pad_x1, int pad_y0, int pad_y1);

#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel,
                        int up_x, int up_y, int down_x, int down_y,
                        int pad_x0, int pad_x1, int pad_y0, int pad_y1) {
    CHECK_CUDA(input);
    CHECK_CUDA(kernel);

    return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
}
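
A small observation, not part of the diff: both C++ wrappers define CHECK_INPUT but only apply CHECK_CUDA, so tensor contiguity is not enforced at this boundary — the Python callers above are responsible for handing the kernels contiguous tensors.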

codes/models/archs/stylegan/op/upfirdn2d.py  (new file, 191 lines)
@@ -0,0 +1,191 @@
import os

import torch
from torch.nn import functional as F
from torch.autograd import Function
from torch.utils.cpp_extension import load
import upfirdn2d_cuda


class UpFirDn2dBackward(Function):
    @staticmethod
    def forward(
        ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size
    ):

        up_x, up_y = up
        down_x, down_y = down
        g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad

        grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)

        grad_input = upfirdn2d_cuda.upfirdn2d(
            grad_output,
            grad_kernel,
            down_x,
            down_y,
            up_x,
            up_y,
            g_pad_x0,
            g_pad_x1,
            g_pad_y0,
            g_pad_y1,
        )
        grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3])

        ctx.save_for_backward(kernel)

        pad_x0, pad_x1, pad_y0, pad_y1 = pad

        ctx.up_x = up_x
        ctx.up_y = up_y
        ctx.down_x = down_x
        ctx.down_y = down_y
        ctx.pad_x0 = pad_x0
        ctx.pad_x1 = pad_x1
        ctx.pad_y0 = pad_y0
        ctx.pad_y1 = pad_y1
        ctx.in_size = in_size
        ctx.out_size = out_size

        return grad_input

    @staticmethod
    def backward(ctx, gradgrad_input):
        kernel, = ctx.saved_tensors

        gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1)

        gradgrad_out = upfirdn2d_cuda.upfirdn2d(
            gradgrad_input,
            kernel,
            ctx.up_x,
            ctx.up_y,
            ctx.down_x,
            ctx.down_y,
            ctx.pad_x0,
            ctx.pad_x1,
            ctx.pad_y0,
            ctx.pad_y1,
        )
        # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3])
        gradgrad_out = gradgrad_out.view(
            ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1]
        )

        return gradgrad_out, None, None, None, None, None, None, None, None


class UpFirDn2d(Function):
    @staticmethod
    def forward(ctx, input, kernel, up, down, pad):
        up_x, up_y = up
        down_x, down_y = down
        pad_x0, pad_x1, pad_y0, pad_y1 = pad

        kernel_h, kernel_w = kernel.shape
        batch, channel, in_h, in_w = input.shape
        ctx.in_size = input.shape

        input = input.reshape(-1, in_h, in_w, 1)

        ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))

        out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
        out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
        ctx.out_size = (out_h, out_w)

        ctx.up = (up_x, up_y)
        ctx.down = (down_x, down_y)
        ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)

        g_pad_x0 = kernel_w - pad_x0 - 1
        g_pad_y0 = kernel_h - pad_y0 - 1
        g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
        g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1

        ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)

        out = upfirdn2d_cuda.upfirdn2d(
            input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
        )
        # out = out.view(major, out_h, out_w, minor)
        out = out.view(-1, channel, out_h, out_w)

        return out

    @staticmethod
    def backward(ctx, grad_output):
        kernel, grad_kernel = ctx.saved_tensors

        grad_input = UpFirDn2dBackward.apply(
            grad_output,
            kernel,
            grad_kernel,
            ctx.up,
            ctx.down,
            ctx.pad,
            ctx.g_pad,
            ctx.in_size,
            ctx.out_size,
        )

        return grad_input, None, None, None, None


def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
    if input.device.type == "cpu":
        out = upfirdn2d_native(
            input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1]
        )
    else:
        out = UpFirDn2d.apply(
            input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1])
        )

    return out


def upfirdn2d_native(
    input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
):
    _, channel, in_h, in_w = input.shape
    input = input.reshape(-1, in_h, in_w, 1)

    _, in_h, in_w, minor = input.shape
    kernel_h, kernel_w = kernel.shape

    out = input.view(-1, in_h, 1, in_w, 1, minor)
    out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
    out = out.view(-1, in_h * up_y, in_w * up_x, minor)

    out = F.pad(
        out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
    )
    out = out[
        :,
        max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
        max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
        :,
    ]

    out = out.permute(0, 3, 1, 2)
    out = out.reshape(
        [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
    )
    w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
    out = F.conv2d(out, w)
    out = out.reshape(
        -1,
        minor,
        in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
        in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
    )
    out = out.permute(0, 2, 3, 1)
    out = out[:, ::down_y, ::down_x, :]

    out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
    out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1

    return out.view(-1, channel, out_h, out_w)
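
A side note, not part of the diff: upfirdn2d is the classic multirate-DSP primitive — zero-stuffed upsampling, FIR filtering, then decimation (the name matches scipy.signal.upfirdn). With up=down=1 it reduces to a padded 2D FIR filter, which makes the CPU fallback easy to sanity-check. A minimal sketch on CPU tensors, so upfirdn2d_native is exercised:

import torch
from models.archs.stylegan.op.upfirdn2d import upfirdn2d

x = torch.randn(2, 3, 16, 16)   # NCHW input
k = torch.ones(3, 3) / 9.0      # 3x3 box-blur kernel

# pad=(1, 1) preserves the spatial size: (16 + 1 + 1 - 3) // 1 + 1 == 16
y = upfirdn2d(x, k, up=1, down=1, pad=(1, 1))
assert y.shape == (2, 3, 16, 16)

# up=2 zero-stuffs to 32x32 before filtering: (16*2 + 1 + 1 - 3) // 1 + 1 == 32
y_up = upfirdn2d(x, k, up=2, down=1, pad=(1, 1))
assert y_up.shape == (2, 3, 32, 32)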

@@ -858,7 +858,7 @@ class StyleGan2DivergenceLoss(L.ConfigurableLoss):

         # Apply gradient penalty. TODO: migrate this elsewhere.
         if self.env['step'] % self.gp_frequency == 0:
-            from models.archs.stylegan.stylegan2 import gradient_penalty
+            from models.archs.stylegan.stylegan2_lucidrains import gradient_penalty
             gp = gradient_penalty(real_input, real)
             self.metrics.append(("gradient_penalty", gp.clone().detach()))
             divergence_loss = divergence_loss + gp

@@ -873,17 +873,17 @@ class StyleGan2PathLengthLoss(L.ConfigurableLoss):
         self.w_styles = opt['w_styles']
         self.gen = opt['gen']
         self.pl_mean = None
-        from models.archs.stylegan.stylegan2 import EMA
+        from models.archs.stylegan.stylegan2_lucidrains import EMA
         self.pl_length_ma = EMA(.99)

     def forward(self, net, state):
         w_styles = state[self.w_styles]
         gen = state[self.gen]
-        from models.archs.stylegan.stylegan2 import calc_pl_lengths
+        from models.archs.stylegan.stylegan2_lucidrains import calc_pl_lengths
         pl_lengths = calc_pl_lengths(w_styles, gen)
         avg_pl_length = np.mean(pl_lengths.detach().cpu().numpy())

-        from models.archs.stylegan.stylegan2 import is_empty
+        from models.archs.stylegan.stylegan2_lucidrains import is_empty
         if not is_empty(self.pl_mean):
             pl_loss = ((pl_lengths - self.pl_mean) ** 2).mean()
             if not torch.isnan(pl_loss):

codes/models/archs/stylegan/stylegan2_rosinality.py  (new file, 660 lines)
@@ -0,0 +1,660 @@
import math
import random
import functools
import operator

import torch
from torch import nn
from torch.nn import functional as F
from torch.autograd import Function

from models.archs.stylegan.op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d


class PixelNorm(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, input):
        return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)


def make_kernel(k):
    k = torch.tensor(k, dtype=torch.float32)

    if k.ndim == 1:
        k = k[None, :] * k[:, None]

    k /= k.sum()

    return k


class Upsample(nn.Module):
    def __init__(self, kernel, factor=2):
        super().__init__()

        self.factor = factor
        kernel = make_kernel(kernel) * (factor ** 2)
        self.register_buffer("kernel", kernel)

        p = kernel.shape[0] - factor

        pad0 = (p + 1) // 2 + factor - 1
        pad1 = p // 2

        self.pad = (pad0, pad1)

    def forward(self, input):
        out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)

        return out


class Downsample(nn.Module):
    def __init__(self, kernel, factor=2):
        super().__init__()

        self.factor = factor
        kernel = make_kernel(kernel)
        self.register_buffer("kernel", kernel)

        p = kernel.shape[0] - factor

        pad0 = (p + 1) // 2
        pad1 = p // 2

        self.pad = (pad0, pad1)

    def forward(self, input):
        out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad)

        return out


class Blur(nn.Module):
    def __init__(self, kernel, pad, upsample_factor=1):
        super().__init__()

        kernel = make_kernel(kernel)

        if upsample_factor > 1:
            kernel = kernel * (upsample_factor ** 2)

        self.register_buffer("kernel", kernel)

        self.pad = pad

    def forward(self, input):
        out = upfirdn2d(input, self.kernel, pad=self.pad)

        return out


class EqualConv2d(nn.Module):
    def __init__(
        self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True
    ):
        super().__init__()

        self.weight = nn.Parameter(
            torch.randn(out_channel, in_channel, kernel_size, kernel_size)
        )
        self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)

        self.stride = stride
        self.padding = padding

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_channel))
        else:
            self.bias = None

    def forward(self, input):
        out = F.conv2d(
            input,
            self.weight * self.scale,
            bias=self.bias,
            stride=self.stride,
            padding=self.padding,
        )

        return out

    def __repr__(self):
        return (
            f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},"
            f" {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})"
        )


class EqualLinear(nn.Module):
    def __init__(
        self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None
    ):
        super().__init__()

        self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
        else:
            self.bias = None

        self.activation = activation

        self.scale = (1 / math.sqrt(in_dim)) * lr_mul
        self.lr_mul = lr_mul

    def forward(self, input):
        if self.activation:
            out = F.linear(input, self.weight * self.scale)
            out = fused_leaky_relu(out, self.bias * self.lr_mul)
        else:
            out = F.linear(
                input, self.weight * self.scale, bias=self.bias * self.lr_mul
            )

        return out

    def __repr__(self):
        return (
            f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})"
        )


class ModulatedConv2d(nn.Module):
    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size,
        style_dim,
        demodulate=True,
        upsample=False,
        downsample=False,
        blur_kernel=[1, 3, 3, 1],
    ):
        super().__init__()

        self.eps = 1e-8
        self.kernel_size = kernel_size
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.upsample = upsample
        self.downsample = downsample

        if upsample:
            factor = 2
            p = (len(blur_kernel) - factor) - (kernel_size - 1)
            pad0 = (p + 1) // 2 + factor - 1
            pad1 = p // 2 + 1

            self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor)

        if downsample:
            factor = 2
            p = (len(blur_kernel) - factor) + (kernel_size - 1)
            pad0 = (p + 1) // 2
            pad1 = p // 2

            self.blur = Blur(blur_kernel, pad=(pad0, pad1))

        fan_in = in_channel * kernel_size ** 2
        self.scale = 1 / math.sqrt(fan_in)
        self.padding = kernel_size // 2

        self.weight = nn.Parameter(
            torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)
        )

        self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)

        self.demodulate = demodulate

    def __repr__(self):
        return (
            f"{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, "
            f"upsample={self.upsample}, downsample={self.downsample})"
        )

    def forward(self, input, style):
        batch, in_channel, height, width = input.shape

        style = self.modulation(style).view(batch, 1, in_channel, 1, 1)
        weight = self.scale * self.weight * style

        if self.demodulate:
            demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
            weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)

        weight = weight.view(
            batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size
        )

        if self.upsample:
            input = input.view(1, batch * in_channel, height, width)
            weight = weight.view(
                batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size
            )
            weight = weight.transpose(1, 2).reshape(
                batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size
            )
            out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch)
            _, _, height, width = out.shape
            out = out.view(batch, self.out_channel, height, width)
            out = self.blur(out)

        elif self.downsample:
            input = self.blur(input)
            _, _, height, width = input.shape
            input = input.view(1, batch * in_channel, height, width)
            out = F.conv2d(input, weight, padding=0, stride=2, groups=batch)
            _, _, height, width = out.shape
            out = out.view(batch, self.out_channel, height, width)

        else:
            input = input.view(1, batch * in_channel, height, width)
            out = F.conv2d(input, weight, padding=self.padding, groups=batch)
            _, _, height, width = out.shape
            out = out.view(batch, self.out_channel, height, width)

        return out


class NoiseInjection(nn.Module):
    def __init__(self):
        super().__init__()

        self.weight = nn.Parameter(torch.zeros(1))

    def forward(self, image, noise=None):
        if noise is None:
            batch, _, height, width = image.shape
            noise = image.new_empty(batch, 1, height, width).normal_()

        return image + self.weight * noise


class ConstantInput(nn.Module):
    def __init__(self, channel, size=4):
        super().__init__()

        self.input = nn.Parameter(torch.randn(1, channel, size, size))

    def forward(self, input):
        batch = input.shape[0]
        out = self.input.repeat(batch, 1, 1, 1)

        return out


class StyledConv(nn.Module):
    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size,
        style_dim,
        upsample=False,
        blur_kernel=[1, 3, 3, 1],
        demodulate=True,
    ):
        super().__init__()

        self.conv = ModulatedConv2d(
            in_channel,
            out_channel,
            kernel_size,
            style_dim,
            upsample=upsample,
            blur_kernel=blur_kernel,
            demodulate=demodulate,
        )

        self.noise = NoiseInjection()
        # self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1))
        # self.activate = ScaledLeakyReLU(0.2)
        self.activate = FusedLeakyReLU(out_channel)

    def forward(self, input, style, noise=None):
        out = self.conv(input, style)
        out = self.noise(out, noise=noise)
        # out = out + self.bias
        out = self.activate(out)

        return out


class ToRGB(nn.Module):
    def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]):
        super().__init__()

        if upsample:
            self.upsample = Upsample(blur_kernel)

        self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False)
        self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))

    def forward(self, input, style, skip=None):
        out = self.conv(input, style)
        out = out + self.bias

        if skip is not None:
            skip = self.upsample(skip)

            out = out + skip

        return out


class Generator(nn.Module):
    def __init__(
        self,
        size,
        style_dim,
        n_mlp,
        channel_multiplier=2,
        blur_kernel=[1, 3, 3, 1],
        lr_mlp=0.01,
    ):
        super().__init__()

        self.size = size

        self.style_dim = style_dim

        layers = [PixelNorm()]

        for i in range(n_mlp):
            layers.append(
                EqualLinear(
                    style_dim, style_dim, lr_mul=lr_mlp, activation="fused_lrelu"
                )
            )

        self.style = nn.Sequential(*layers)

        self.channels = {
            4: 512,
            8: 512,
            16: 512,
            32: 512,
            64: 256 * channel_multiplier,
            128: 128 * channel_multiplier,
            256: 64 * channel_multiplier,
            512: 32 * channel_multiplier,
            1024: 16 * channel_multiplier,
        }

        self.input = ConstantInput(self.channels[4])
        self.conv1 = StyledConv(
            self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel
        )
        self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False)

        self.log_size = int(math.log(size, 2))
        self.num_layers = (self.log_size - 2) * 2 + 1

        self.convs = nn.ModuleList()
        self.upsamples = nn.ModuleList()
        self.to_rgbs = nn.ModuleList()
        self.noises = nn.Module()

        in_channel = self.channels[4]

        for layer_idx in range(self.num_layers):
            res = (layer_idx + 5) // 2
            shape = [1, 1, 2 ** res, 2 ** res]
            self.noises.register_buffer(f"noise_{layer_idx}", torch.randn(*shape))

        for i in range(3, self.log_size + 1):
            out_channel = self.channels[2 ** i]

            self.convs.append(
                StyledConv(
                    in_channel,
                    out_channel,
                    3,
                    style_dim,
                    upsample=True,
                    blur_kernel=blur_kernel,
                )
            )

            self.convs.append(
                StyledConv(
                    out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel
                )
            )

            self.to_rgbs.append(ToRGB(out_channel, style_dim))

            in_channel = out_channel

        self.n_latent = self.log_size * 2 - 2

    def make_noise(self):
        device = self.input.input.device

        noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)]

        for i in range(3, self.log_size + 1):
            for _ in range(2):
                noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device))

        return noises

    def mean_latent(self, n_latent):
        latent_in = torch.randn(
            n_latent, self.style_dim, device=self.input.input.device
        )
        latent = self.style(latent_in).mean(0, keepdim=True)

        return latent

    def get_latent(self, input):
        return self.style(input)

    def forward(
        self,
        styles,
        return_latents=False,
        inject_index=None,
        truncation=1,
        truncation_latent=None,
        input_is_latent=False,
        noise=None,
        randomize_noise=True,
    ):
        if not input_is_latent:
            styles = [self.style(s) for s in styles]

        if noise is None:
            if randomize_noise:
                noise = [None] * self.num_layers
            else:
                noise = [
                    getattr(self.noises, f"noise_{i}") for i in range(self.num_layers)
                ]

        if truncation < 1:
            style_t = []

            for style in styles:
                style_t.append(
                    truncation_latent + truncation * (style - truncation_latent)
                )

            styles = style_t

        if len(styles) < 2:
            inject_index = self.n_latent

            if styles[0].ndim < 3:
                latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            else:
                latent = styles[0]

        else:
            if inject_index is None:
                inject_index = random.randint(1, self.n_latent - 1)

            latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)

            latent = torch.cat([latent, latent2], 1)

        out = self.input(latent)
        out = self.conv1(out, latent[:, 0], noise=noise[0])

        skip = self.to_rgb1(out, latent[:, 1])

        i = 1
        for conv1, conv2, noise1, noise2, to_rgb in zip(
            self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
        ):
            out = conv1(out, latent[:, i], noise=noise1)
            out = conv2(out, latent[:, i + 1], noise=noise2)
            skip = to_rgb(out, latent[:, i + 2], skip)

            i += 2

        image = skip

        if return_latents:
            return image, latent
        else:
            return image, None


class ConvLayer(nn.Sequential):
    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size,
        downsample=False,
        blur_kernel=[1, 3, 3, 1],
        bias=True,
        activate=True,
    ):
        layers = []

        if downsample:
            factor = 2
            p = (len(blur_kernel) - factor) + (kernel_size - 1)
            pad0 = (p + 1) // 2
            pad1 = p // 2

            layers.append(Blur(blur_kernel, pad=(pad0, pad1)))

            stride = 2
            self.padding = 0

        else:
            stride = 1
            self.padding = kernel_size // 2

        layers.append(
            EqualConv2d(
                in_channel,
                out_channel,
                kernel_size,
                padding=self.padding,
                stride=stride,
                bias=bias and not activate,
            )
        )

        if activate:
            layers.append(FusedLeakyReLU(out_channel, bias=bias))

        super().__init__(*layers)


class ResBlock(nn.Module):
    def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
        super().__init__()

        self.conv1 = ConvLayer(in_channel, in_channel, 3)
        self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)

        self.skip = ConvLayer(
            in_channel, out_channel, 1, downsample=True, activate=False, bias=False
        )

    def forward(self, input):
        out = self.conv1(input)
        out = self.conv2(out)

        skip = self.skip(input)
        out = (out + skip) / math.sqrt(2)

        return out


class Discriminator(nn.Module):
    def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]):
        super().__init__()

        channels = {
            4: 512,
            8: 512,
            16: 512,
            32: 512,
            64: 256 * channel_multiplier,
            128: 128 * channel_multiplier,
            256: 64 * channel_multiplier,
            512: 32 * channel_multiplier,
            1024: 16 * channel_multiplier,
        }

        convs = [ConvLayer(3, channels[size], 1)]

        log_size = int(math.log(size, 2))

        in_channel = channels[size]

        for i in range(log_size, 2, -1):
            out_channel = channels[2 ** (i - 1)]

            convs.append(ResBlock(in_channel, out_channel, blur_kernel))

            in_channel = out_channel

        self.convs = nn.Sequential(*convs)

        self.stddev_group = 4
        self.stddev_feat = 1

        self.final_conv = ConvLayer(in_channel + 1, channels[4], 3)
        self.final_linear = nn.Sequential(
            EqualLinear(channels[4] * 4 * 4, channels[4], activation="fused_lrelu"),
            EqualLinear(channels[4], 1),
        )

    def forward(self, input):
        out = self.convs(input)

        batch, channel, height, width = out.shape
        group = min(batch, self.stddev_group)
        stddev = out.view(
            group, -1, self.stddev_feat, channel // self.stddev_feat, height, width
        )
        stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
        stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
        stddev = stddev.repeat(group, 1, height, width)
        out = torch.cat([out, stddev], 1)

        out = self.final_conv(out)

        out = out.view(batch, -1)
        out = self.final_linear(out)

        return out
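
Usage sketch, not part of the diff: the ported Generator consumes a list of z vectors (a list so that style mixing is possible) and returns an (image, latent) tuple. Assuming the CUDA ops above are built and a GPU is available:

import torch
from models.archs.stylegan.stylegan2_rosinality import Generator, Discriminator

g = Generator(size=256, style_dim=512, n_mlp=8).cuda()
d = Discriminator(256).cuda()

z = torch.randn(4, 512, device="cuda")
img, _ = g([z])        # img: (4, 3, 256, 256)
logits = d(img)        # logits: (4, 1) realness scores

# Truncation trick: pull w toward the mean latent for cleaner samples.
w_mean = g.mean_latent(4096)
img_t, _ = g([z], truncation=0.7, truncation_latent=w_mean)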

@@ -1,243 +0,0 @@  (deleted file: the StyleGan2 U-net discriminator and its divergence loss)
from functools import partial
from math import log2
from random import random
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import models.archs.stylegan.stylegan2 as sg2
import models.steps.losses as L


def leaky_relu(p=0.2):
    return nn.LeakyReLU(p)


def double_conv(chan_in, chan_out):
    return nn.Sequential(
        nn.Conv2d(chan_in, chan_out, 3, padding=1),
        leaky_relu(),
        nn.Conv2d(chan_out, chan_out, 3, padding=1),
        leaky_relu()
    )


class Flatten(nn.Module):
    def __init__(self, index):
        super().__init__()
        self.index = index

    def forward(self, x):
        return x.flatten(self.index)


class DownBlock(nn.Module):
    def __init__(self, input_channels, filters, downsample=True):
        super().__init__()
        self.conv_res = nn.Conv2d(input_channels, filters, 1, stride = (2 if downsample else 1))

        self.net = double_conv(input_channels, filters)
        self.down = nn.Conv2d(filters, filters, 3, padding = 1, stride = 2) if downsample else None

    def forward(self, x):
        res = self.conv_res(x)
        x = self.net(x)
        unet_res = x

        if self.down is not None:
            x = self.down(x)

        x = x + res
        return x, unet_res


class UpBlock(nn.Module):
    def __init__(self, input_channels, filters):
        super().__init__()
        self.conv_res = nn.ConvTranspose2d(input_channels // 2, filters, 1, stride = 2)
        self.net = double_conv(input_channels, filters)
        self.up = nn.Upsample(scale_factor = 2, mode='bilinear', align_corners=False)
        self.input_channels = input_channels
        self.filters = filters

    def forward(self, x, res):
        *_, h, w = x.shape
        conv_res = self.conv_res(x, output_size = (h * 2, w * 2))
        x = self.up(x)
        x = torch.cat((x, res), dim=1)
        x = self.net(x)
        x = x + conv_res
        return x


class StyleGan2UnetDiscriminator(nn.Module):
    def __init__(self, image_size, network_capacity = 16, fmap_max = 512, input_filters=3):
        super().__init__()
        num_layers = int(log2(image_size) - 3)

        blocks = []
        filters = [input_filters] + [(network_capacity) * (2 ** i) for i in range(num_layers + 1)]

        set_fmap_max = partial(min, fmap_max)
        filters = list(map(set_fmap_max, filters))
        filters[-1] = filters[-2]

        chan_in_out = list(zip(filters[:-1], filters[1:]))
        chan_in_out = list(map(list, chan_in_out))

        down_blocks = []
        attn_blocks = []

        for ind, (in_chan, out_chan) in enumerate(chan_in_out):
            num_layer = ind + 1
            is_not_last = ind != (len(chan_in_out) - 1)

            block = DownBlock(in_chan, out_chan, downsample = is_not_last)
            down_blocks.append(block)

            attn_fn = sg2.attn_and_ff(out_chan)
            attn_blocks.append(attn_fn)

        self.down_blocks = nn.ModuleList(down_blocks)
        self.attn_blocks = nn.ModuleList(attn_blocks)

        last_chan = filters[-1]

        self.to_logit = nn.Sequential(
            leaky_relu(),
            nn.AvgPool2d(image_size // (2 ** num_layers)),
            Flatten(1),
            nn.Linear(last_chan, 1)
        )

        self.conv = double_conv(last_chan, last_chan)

        dec_chan_in_out = chan_in_out[:-1][::-1]
        self.up_blocks = nn.ModuleList(list(map(lambda c: UpBlock(c[1] * 2, c[0]), dec_chan_in_out)))
        self.conv_out = nn.Conv2d(input_filters, 1, 1)

    def forward(self, x):
        b, *_ = x.shape

        residuals = []

        for (down_block, attn_block) in zip(self.down_blocks, self.attn_blocks):
            x, unet_res = down_block(x)
            residuals.append(unet_res)

            if attn_block is not None:
                x = attn_block(x)

        x = self.conv(x) + x
        enc_out = self.to_logit(x)

        for (up_block, res) in zip(self.up_blocks, residuals[:-1][::-1]):
            x = up_block(x, res)

        dec_out = self.conv_out(x)
        return dec_out, enc_out


def warmup(start, end, max_steps, current_step):
    if current_step > max_steps:
        return end
    return (end - start) * (current_step / max_steps) + start


def mask_src_tgt(source, target, mask):
    return source * mask + (1 - mask) * target


def cutmix(source, target, coors, alpha = 1.):
    source, target = map(torch.clone, (source, target))
    ((y0, y1), (x0, x1)), _ = coors
    source[:, :, y0:y1, x0:x1] = target[:, :, y0:y1, x0:x1]
    return source


def cutmix_coordinates(height, width, alpha = 1.):
    lam = np.random.beta(alpha, alpha)

    cx = np.random.uniform(0, width)
    cy = np.random.uniform(0, height)
    w = width * np.sqrt(1 - lam)
    h = height * np.sqrt(1 - lam)
    x0 = int(np.round(max(cx - w / 2, 0)))
    x1 = int(np.round(min(cx + w / 2, width)))
    y0 = int(np.round(max(cy - h / 2, 0)))
    y1 = int(np.round(min(cy + h / 2, height)))

    return ((y0, y1), (x0, x1)), lam


class StyleGan2UnetDivergenceLoss(L.ConfigurableLoss):
    def __init__(self, opt, env):
        super().__init__(opt, env)
        self.real = opt['real']
        self.fake = opt['fake']
        self.discriminator = opt['discriminator']
        self.for_gen = opt['gen_loss']
        self.gp_frequency = opt['gradient_penalty_frequency']
        self.noise = opt['noise'] if 'noise' in opt.keys() else 0
        self.image_size = opt['image_size']
        self.cr_weight = .2

    def forward(self, net, state):
        real_input = state[self.real]
        fake_input = state[self.fake]
        if self.noise != 0:
            fake_input = fake_input + torch.rand_like(fake_input) * self.noise
            real_input = real_input + torch.rand_like(real_input) * self.noise

        D = self.env['discriminators'][self.discriminator]
        fake_dec, fake_enc = D(fake_input)
        fake_aug_images = D.module.aug_images
        if self.for_gen:
            return fake_enc.mean() + F.relu(1 + fake_dec).mean()
        else:
            dec_loss_coef = warmup(0, 1., 30000, self.env['step'])
            cutmix_prob = warmup(0, 0.25, 30000, self.env['step'])
            apply_cutmix = random() < cutmix_prob

            real_input.requires_grad_()  # <-- Needed to compute gradients on the input.
            real_dec, real_enc = D(real_input)
            real_aug_images = D.module.aug_images
            enc_divergence = (F.relu(1 + real_enc) + F.relu(1 - fake_enc)).mean()
            dec_divergence = (F.relu(1 + real_dec) + F.relu(1 - fake_dec)).mean()
            disc_loss = enc_divergence + dec_divergence * dec_loss_coef

            if apply_cutmix:
                mask = cutmix(
                    torch.ones_like(real_dec),
                    torch.zeros_like(real_dec),
                    cutmix_coordinates(self.image_size, self.image_size)
                )

                if random() > 0.5:
                    mask = 1 - mask

                cutmix_images = mask_src_tgt(real_aug_images, fake_aug_images, mask)
                cutmix_dec_out, cutmix_enc_out = D.module.D(cutmix_images)  # Bypass implied augmentor - hence D.module.D

                cutmix_enc_divergence = F.relu(1 - cutmix_enc_out).mean()
                cutmix_dec_divergence = F.relu(1 + (mask * 2 - 1) * cutmix_dec_out).mean()
                disc_loss = disc_loss + cutmix_enc_divergence + cutmix_dec_divergence

                cr_cutmix_dec_out = mask_src_tgt(real_dec, fake_dec, mask)
                cr_loss = F.mse_loss(cutmix_dec_out, cr_cutmix_dec_out) * self.cr_weight
                self.last_cr_loss = cr_loss.clone().detach().item()

                disc_loss = disc_loss + cr_loss * dec_loss_coef

            # Apply gradient penalty. TODO: migrate this elsewhere.
            if self.env['step'] % self.gp_frequency == 0:
                from models.archs.stylegan.stylegan2 import gradient_penalty
                if random() < .5:
                    gp = gradient_penalty(real_input, real_enc)
                else:
                    gp = gradient_penalty(real_input, real_dec) * dec_loss_coef
                self.metrics.append(("gradient_penalty", gp.clone().detach()))
                disc_loss = disc_loss + gp

            real_input.requires_grad_(requires_grad=False)
            return disc_loss
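
A reading note, not part of the diff: this deletion drops the U-net discriminator and its divergence loss entirely, which is consistent with the stylegan2_unet import and branch removals in the hunks above and below.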

@@ -6,8 +6,7 @@ import munch
 import torch
 import torchvision
 from munch import munchify
-import models.archs.stylegan.stylegan2 as stylegan2
-import models.archs.stylegan.stylegan2_unet_disc as stylegan2_unet
+import models.archs.stylegan.stylegan2_lucidrains as stylegan2

 import models.archs.fixup_resnet.DiscriminatorResnet_arch as DiscriminatorResnet_arch
 import models.archs.RRDBNet_arch as RRDBNet_arch
@@ -228,9 +227,6 @@ def define_D_net(opt_net, img_sz=None, wrap=False):
         attn = opt_net['attn_layers'] if 'attn_layers' in opt_net.keys() else []
         disc = stylegan2.StyleGan2Discriminator(image_size=opt_net['image_size'], input_filters=opt_net['in_nc'], attn_layers=attn)
         netD = stylegan2.StyleGan2Augmentor(disc, opt_net['image_size'], types=opt_net['augmentation_types'], prob=opt_net['augmentation_probability'])
-    elif which_model == "stylegan2_unet":
-        disc = stylegan2_unet.StyleGan2UnetDiscriminator(image_size=opt_net['image_size'], input_filters=opt_net['in_nc'])
-        netD = stylegan2.StyleGan2Augmentor(disc, opt_net['image_size'], types=opt_net['augmentation_types'], prob=opt_net['augmentation_probability'])
     elif which_model == "rrdb_disc":
         netD = RRDBNet_arch.RRDBDiscriminator(opt_net['in_nc'], opt_net['nf'], opt_net['nb'], blocks_per_checkpoint=3)
     else:

@@ -312,7 +312,7 @@ class DiscriminatorGanLoss(ConfigurableLoss):

         if self.gradient_penalty:
             # Apply gradient penalty. TODO: migrate this elsewhere.
-            from models.archs.stylegan.stylegan2 import gradient_penalty
+            from models.archs.stylegan.stylegan2_lucidrains import gradient_penalty
             assert len(real) == 1   # Grad penalty doesn't currently support multi-input discriminators.
             gp = gradient_penalty(real[0], d_real)
             self.metrics.append(("gradient_penalty", gp.clone().detach()))

@@ -1,6 +1,6 @@
 from torch.cuda.amp import autocast

-from models.archs.stylegan.stylegan2 import gradient_penalty
+from models.archs.stylegan.stylegan2_lucidrains import gradient_penalty
 from models.steps.losses import ConfigurableLoss, GANLoss, extract_params_from_state, get_basic_criterion_for_name
 from models.flownet2.networks.resample2d_package.resample2d import Resample2d
 from models.steps.injectors import Injector

codes/scripts/stylegan2/convert_weights.py  (new file, 292 lines)
@@ -0,0 +1,292 @@
# Converts from Tensorflow Stylegan2 weights to weights used by this model.
# Original source: https://raw.githubusercontent.com/rosinality/stylegan2-pytorch/master/convert_weight.py
# Adapted to lucidrains' Stylegan implementation.
#
# Also doesn't require you to install Tensorflow 1.15 or clone the nVidia repo.

import argparse
import os
import sys
import pickle
import math

import torch
import numpy as np
from torchvision import utils

from models.archs.stylegan.stylegan2_rosinality import Generator, Discriminator


# Converts from the TF state_dict input provided into the vars originally expected from the rosinality converter.
def get_vars(vars, source_name):
    net_name = source_name.split('/')[0]
    vars_as_tuple_list = vars[net_name]['variables']
    result_vars = {}
    for t in vars_as_tuple_list:
        result_vars[t[0]] = t[1]
    return result_vars, source_name.replace(net_name + "/", "")

def get_vars_direct(vars, source_name):
    v, n = get_vars(vars, source_name)
    return v[n]


def convert_modconv(vars, source_name, target_name, flip=False):
    vars, source_name = get_vars(vars, source_name)
    weight = vars[source_name + "/weight"]
    mod_weight = vars[source_name + "/mod_weight"]
    mod_bias = vars[source_name + "/mod_bias"]
    noise = vars[source_name + "/noise_strength"]
    bias = vars[source_name + "/bias"]

    dic = {
        "conv.weight": np.expand_dims(weight.transpose((3, 2, 0, 1)), 0),
        "conv.modulation.weight": mod_weight.transpose((1, 0)),
        "conv.modulation.bias": mod_bias + 1,
        "noise.weight": np.array([noise]),
        "activate.bias": bias,
    }

    dic_torch = {}

    for k, v in dic.items():
        dic_torch[target_name + "." + k] = torch.from_numpy(v)

    if flip:
        dic_torch[target_name + ".conv.weight"] = torch.flip(
            dic_torch[target_name + ".conv.weight"], [3, 4]
        )

    return dic_torch


def convert_conv(vars, source_name, target_name, bias=True, start=0):
    vars, source_name = get_vars(vars, source_name)
    weight = vars[source_name + "/weight"]

    dic = {"weight": weight.transpose((3, 2, 0, 1))}

    if bias:
        dic["bias"] = vars[source_name + "/bias"]

    dic_torch = {}

    dic_torch[target_name + f".{start}.weight"] = torch.from_numpy(dic["weight"])

    if bias:
        dic_torch[target_name + f".{start + 1}.bias"] = torch.from_numpy(dic["bias"])

    return dic_torch


def convert_torgb(vars, source_name, target_name):
    vars, source_name = get_vars(vars, source_name)
    weight = vars[source_name + "/weight"]
    mod_weight = vars[source_name + "/mod_weight"]
    mod_bias = vars[source_name + "/mod_bias"]
    bias = vars[source_name + "/bias"]

    dic = {
        "conv.weight": np.expand_dims(weight.transpose((3, 2, 0, 1)), 0),
        "conv.modulation.weight": mod_weight.transpose((1, 0)),
        "conv.modulation.bias": mod_bias + 1,
        "bias": bias.reshape((1, 3, 1, 1)),
    }

    dic_torch = {}

    for k, v in dic.items():
        dic_torch[target_name + "." + k] = torch.from_numpy(v)

    return dic_torch


def convert_dense(vars, source_name, target_name):
    vars, source_name = get_vars(vars, source_name)
    weight = vars[source_name + "/weight"]
    bias = vars[source_name + "/bias"]

    dic = {"weight": weight.transpose((1, 0)), "bias": bias}

    dic_torch = {}

    for k, v in dic.items():
        dic_torch[target_name + "." + k] = torch.from_numpy(v)

    return dic_torch


def update(state_dict, new):
    for k, v in new.items():
        state_dict[k] = v


def discriminator_fill_statedict(statedict, vars, size):
    log_size = int(math.log(size, 2))

    update(statedict, convert_conv(vars, f"{size}x{size}/FromRGB", "convs.0"))

    conv_i = 1

    for i in range(log_size - 2, 0, -1):
        reso = 4 * 2 ** i
        update(
            statedict,
            convert_conv(vars, f"{reso}x{reso}/Conv0", f"convs.{conv_i}.conv1"),
        )
        update(
            statedict,
            convert_conv(
                vars, f"{reso}x{reso}/Conv1_down", f"convs.{conv_i}.conv2", start=1
            ),
        )
        update(
            statedict,
            convert_conv(
                vars, f"{reso}x{reso}/Skip", f"convs.{conv_i}.skip", start=1, bias=False
            ),
        )
        conv_i += 1

    update(statedict, convert_conv(vars, f"4x4/Conv", "final_conv"))
    update(statedict, convert_dense(vars, f"4x4/Dense0", "final_linear.0"))
    update(statedict, convert_dense(vars, f"Output", "final_linear.1"))

    return statedict


def fill_statedict(state_dict, vars, size):
    log_size = int(math.log(size, 2))

    for i in range(8):
        update(state_dict, convert_dense(vars, f"G_mapping/Dense{i}", f"style.{i + 1}"))

    update(
        state_dict,
        {
            "input.input": torch.from_numpy(
                get_vars_direct(vars, "G_synthesis/4x4/Const/const")
            )
        },
    )

    update(state_dict, convert_torgb(vars, "G_synthesis/4x4/ToRGB", "to_rgb1"))

    for i in range(log_size - 2):
        reso = 4 * 2 ** (i + 1)
        update(
            state_dict,
            convert_torgb(vars, f"G_synthesis/{reso}x{reso}/ToRGB", f"to_rgbs.{i}"),
        )

    update(state_dict, convert_modconv(vars, "G_synthesis/4x4/Conv", "conv1"))

    conv_i = 0

    for i in range(log_size - 2):
        reso = 4 * 2 ** (i + 1)
        update(
            state_dict,
            convert_modconv(
                vars,
                f"G_synthesis/{reso}x{reso}/Conv0_up",
                f"convs.{conv_i}",
                flip=True,
            ),
        )
        update(
            state_dict,
            convert_modconv(
                vars, f"G_synthesis/{reso}x{reso}/Conv1", f"convs.{conv_i + 1}"
            ),
        )
        conv_i += 2

    for i in range(0, (log_size - 2) * 2 + 1):
        update(
            state_dict,
            {
                f"noises.noise_{i}": torch.from_numpy(
                    get_vars_direct(vars, f"G_synthesis/noise{i}")
                )
            },
        )

    return state_dict


if __name__ == "__main__":
    device = "cuda"

    parser = argparse.ArgumentParser(
        description="Tensorflow to pytorch model checkpoint converter"
    )
    parser.add_argument(
        "--gen", action="store_true", help="convert the generator weights"
    )
    parser.add_argument(
        "--disc", action="store_true", help="convert the discriminator weights"
    )
    parser.add_argument(
        "--channel_multiplier",
        type=int,
        default=2,
        help="channel multiplier factor. config-f = 2, else = 1",
    )
    parser.add_argument("path", metavar="PATH", help="path to the tensorflow weights")

    args = parser.parse_args()
    sys.path.append('scripts\\stylegan2')

    import dnnlib
    from dnnlib.tflib.network import generator, discriminator, gen_ema

    with open(args.path, "rb") as f:
        pickle.load(f)

    # Weight names are ordered by size. The last name will be something like '1024x1024/<blah>'. We just need to grab that first number.
    size = int(generator['G_synthesis']['variables'][-1][0].split('x')[0])

    g = Generator(size, 512, 8, channel_multiplier=args.channel_multiplier)
    state_dict = g.state_dict()
    state_dict = fill_statedict(state_dict, gen_ema, size)

    g.load_state_dict(state_dict, strict=True)

    latent_avg = torch.from_numpy(get_vars_direct(gen_ema, "G/dlatent_avg"))

    ckpt = {"g_ema": state_dict, "latent_avg": latent_avg}

    if args.gen:
        g_train = Generator(size, 512, 8, channel_multiplier=args.channel_multiplier)
        g_train_state = g_train.state_dict()
        g_train_state = fill_statedict(g_train_state, generator, size)
        ckpt["g"] = g_train_state

    if args.disc:
        disc = Discriminator(size, channel_multiplier=args.channel_multiplier)
        d_state = disc.state_dict()
        d_state = discriminator_fill_statedict(d_state, discriminator.vars, size)
        ckpt["d"] = d_state

    name = os.path.splitext(os.path.basename(args.path))[0]
    torch.save(ckpt, name + ".pt")

    batch_size = {256: 16, 512: 9, 1024: 4}
    n_sample = batch_size.get(size, 25)

    g = g.to(device)

    z = np.random.RandomState(0).randn(n_sample, 512).astype("float32")

    with torch.no_grad():
        img_pt, _ = g(
            [torch.from_numpy(z).to(device)],
            truncation=0.5,
            truncation_latent=latent_avg.to(device),
            randomize_noise=False,
        )

    utils.save_image(
        img_pt, name + ".png", nrow=n_sample, normalize=True, range=(-1, 1)
    )
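
Usage note, not part of the diff: the converter is run as a script — e.g. `python scripts/stylegan2/convert_weights.py --gen --disc <path-to-tf-pickle>` — and the Windows-style `sys.path.append('scripts\\stylegan2')` suggests it is meant to be launched from the codes directory so the bundled dnnlib stub below shadows NVIDIA's. It writes <name>.pt containing g_ema (plus g and d when the flags are given) and saves a truncated sample grid to <name>.png as a conversion check.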

codes/scripts/stylegan2/dnnlib/tflib/network.py  (new file, 17 lines)
@@ -0,0 +1,17 @@
# Pretends to be the stylegan2 Network class for intercepting pickle load requests.
# Horrible hack. Please don't judge me.

# Globals for storing these networks because I have no idea how pickle is doing this internally.
generator, discriminator, gen_ema = {}, {}, {}

class Network:
    def __setstate__(self, state: dict) -> None:
        global generator, discriminator, gen_ema
        name = state['name']
        if name in ['G_synthesis', 'G_mapping', 'G', 'G_main']:
            if name != 'G' and name not in generator.keys():
                generator[name] = state
            else:
                gen_ema[name] = state
        elif name in ['D']:
            discriminator[name] = state
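
Why this stub works, as a side note: the NVIDIA checkpoint pickle records its network objects under the class path dnnlib.tflib.network.Network, so unpickling imports whatever module resolves to that name. With this module shadowing the real one, every network in the pickle is routed through __setstate__ above, which stashes the raw variable lists in the module-level dicts instead of rebuilding TensorFlow graphs. A minimal sketch of the consuming side, mirroring convert_weights.py (checkpoint filename illustrative):

import pickle
from dnnlib.tflib.network import generator, gen_ema

with open('stylegan2-ffhq-config-f.pkl', 'rb') as f:
    pickle.load(f)   # return value discarded; the stub fills the dicts as a side effect

# Each entry maps a network name to its pickled state, including the
# (variable_name, ndarray) tuples under ['variables'].
print(sorted(gen_ema.keys()))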