Get rosinality model converter to work

Mostly, I just needed to remove the custom CUDA ops, which are not so bueno on Windows.
James Betker 2020-12-17 16:03:39 -07:00
parent e838c6e75b
commit 3074f41877
9 changed files with 85 additions and 382 deletions

View File: models/archs/stylegan/op/__init__.py

@@ -1,2 +0,0 @@
from .fused_act import FusedLeakyReLU, fused_leaky_relu
from .upfirdn2d import upfirdn2d

View File: models/archs/stylegan/op/fused_act.py

@@ -1,110 +0,0 @@
import os

import torch
from torch import nn
from torch.nn import functional as F
from torch.autograd import Function
from torch.utils.cpp_extension import load

import fused_bias_act_cuda


class FusedLeakyReLUFunctionBackward(Function):
    @staticmethod
    def forward(ctx, grad_output, out, bias, negative_slope, scale):
        ctx.save_for_backward(out)
        ctx.negative_slope = negative_slope
        ctx.scale = scale

        empty = grad_output.new_empty(0)

        grad_input = fused_bias_act_cuda.fused_bias_act(
            grad_output, empty, out, 3, 1, negative_slope, scale
        )

        dim = [0]

        if grad_input.ndim > 2:
            dim += list(range(2, grad_input.ndim))

        if bias:
            grad_bias = grad_input.sum(dim).detach()
        else:
            grad_bias = empty

        return grad_input, grad_bias

    @staticmethod
    def backward(ctx, gradgrad_input, gradgrad_bias):
        out, = ctx.saved_tensors
        gradgrad_out = fused_bias_act_cuda.fused_bias_act(
            gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale
        )

        return gradgrad_out, None, None, None, None


class FusedLeakyReLUFunction(Function):
    @staticmethod
    def forward(ctx, input, bias, negative_slope, scale):
        empty = input.new_empty(0)

        ctx.bias = bias is not None

        if bias is None:
            bias = empty

        out = fused_bias_act_cuda.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale)
        ctx.save_for_backward(out)
        ctx.negative_slope = negative_slope
        ctx.scale = scale

        return out

    @staticmethod
    def backward(ctx, grad_output):
        out, = ctx.saved_tensors

        grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(
            grad_output, out, ctx.bias, ctx.negative_slope, ctx.scale
        )

        if not ctx.bias:
            grad_bias = None

        return grad_input, grad_bias, None, None


class FusedLeakyReLU(nn.Module):
    def __init__(self, channel, bias=True, negative_slope=0.2, scale=2 ** 0.5):
        super().__init__()

        if bias:
            self.bias = nn.Parameter(torch.zeros(channel))
        else:
            self.bias = None

        self.negative_slope = negative_slope
        self.scale = scale

    def forward(self, input):
        return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)


def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2 ** 0.5):
    if input.device.type == "cpu":
        if bias is not None:
            rest_dim = [1] * (input.ndim - bias.ndim - 1)
            return (
                F.leaky_relu(
                    input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=0.2
                )
                * scale
            )
        else:
            return F.leaky_relu(input, negative_slope=0.2) * scale
    else:
        return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)
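What the fused op computes, in one line: out = leaky_relu(x + bias) * sqrt(2), where the sqrt(2) rescale keeps activation magnitudes roughly constant through the network. A quick sanity check of the CPU path against that formula (a sketch, not part of the commit; assumes the definitions above are in scope):

import torch
from torch.nn import functional as F

x = torch.randn(2, 8, 16, 16)   # NCHW activations on the CPU
b = torch.randn(8)              # per-channel bias
ref = F.leaky_relu(x + b.view(1, -1, 1, 1), negative_slope=0.2) * (2 ** 0.5)
out = fused_leaky_relu(x, b)    # x is on the CPU, so this takes the native branch
assert torch.allclose(out, ref)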

View File: models/archs/stylegan/op/fused_bias_act.cpp

@@ -1,21 +0,0 @@
#include <torch/extension.h>

torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
                                int act, int grad, float alpha, float scale);

#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
                             int act, int grad, float alpha, float scale) {
    CHECK_CUDA(input);
    CHECK_CUDA(bias);

    return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)");
}
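For context, the upstream repo JIT-compiles this binding at import time rather than pre-building it (hence the `torch.utils.cpp_extension.load` import in fused_act.py above). A sketch of that approach, assuming the sources sit next to the calling module; the variable names here are illustrative:

import os
from torch.utils.cpp_extension import load

module_path = os.path.dirname(__file__)
# JIT-compile the C++/CUDA sources into an importable module named fused_bias_act_cuda.
fused_bias_act_cuda = load(
    "fused_bias_act_cuda",
    sources=[
        os.path.join(module_path, "fused_bias_act.cpp"),
        os.path.join(module_path, "fused_bias_act_kernel.cu"),
    ],
)

Either way, a working nvcc plus a matching host compiler is required at run or build time, which is exactly the dependency this commit removes.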

View File: models/archs/stylegan/op/setup.py

@@ -1,32 +0,0 @@
#!/usr/bin/env python3
import os

import torch
from setuptools import setup, find_packages
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

cxx_args = ['-std=c++11']

nvcc_args = [
    '-gencode', 'arch=compute_50,code=sm_50',
    '-gencode', 'arch=compute_52,code=sm_52',
    '-gencode', 'arch=compute_60,code=sm_60',
    '-gencode', 'arch=compute_61,code=sm_61',
    '-gencode', 'arch=compute_70,code=sm_70',
]

setup(
    name='stylegan2_ops_cuda',
    ext_modules=[
        CUDAExtension('fused_bias_act_cuda', [
            'fused_bias_act.cpp',
            'fused_bias_act_kernel.cu'
        ], extra_compile_args={'cxx': cxx_args, 'nvcc': nvcc_args}),
        CUDAExtension('upfirdn2d_cuda', [
            'upfirdn2d.cpp',
            'upfirdn2d_kernel.cu'
        ], extra_compile_args={'cxx': cxx_args, 'nvcc': nvcc_args})
    ],
    cmdclass={
        'build_ext': BuildExtension
    })
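The `-gencode` pairs bake in kernels for compute capabilities 5.0 through 7.0 (Maxwell through Volta). Building these extensions needs nvcc plus a matching host compiler (MSVC on Windows), the toolchain headache this commit sidesteps; on a machine with that toolchain they would be built with something like `python setup.py install` from the op directory.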

View File: models/archs/stylegan/op/upfirdn2d.cpp

@@ -1,23 +0,0 @@
#include <torch/extension.h>

torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,
                           int up_x, int up_y, int down_x, int down_y,
                           int pad_x0, int pad_x1, int pad_y0, int pad_y1);

#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel,
                        int up_x, int up_y, int down_x, int down_y,
                        int pad_x0, int pad_x1, int pad_y0, int pad_y1) {
    CHECK_CUDA(input);
    CHECK_CUDA(kernel);

    return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
}
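Same pattern as fused_bias_act.cpp: assert that the tensors live on the GPU, then forward straight to the .cu kernel. The CHECK_CONTIGUOUS/CHECK_INPUT macros are defined but unused in both bindings.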

View File: models/archs/stylegan/op/upfirdn2d.py

@@ -1,191 +0,0 @@
import os

import torch
from torch.nn import functional as F
from torch.autograd import Function
from torch.utils.cpp_extension import load

import upfirdn2d_cuda


class UpFirDn2dBackward(Function):
    @staticmethod
    def forward(
        ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size
    ):
        up_x, up_y = up
        down_x, down_y = down
        g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad

        grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)

        grad_input = upfirdn2d_cuda.upfirdn2d(
            grad_output,
            grad_kernel,
            down_x,
            down_y,
            up_x,
            up_y,
            g_pad_x0,
            g_pad_x1,
            g_pad_y0,
            g_pad_y1,
        )
        grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3])

        ctx.save_for_backward(kernel)

        pad_x0, pad_x1, pad_y0, pad_y1 = pad

        ctx.up_x = up_x
        ctx.up_y = up_y
        ctx.down_x = down_x
        ctx.down_y = down_y
        ctx.pad_x0 = pad_x0
        ctx.pad_x1 = pad_x1
        ctx.pad_y0 = pad_y0
        ctx.pad_y1 = pad_y1
        ctx.in_size = in_size
        ctx.out_size = out_size

        return grad_input

    @staticmethod
    def backward(ctx, gradgrad_input):
        kernel, = ctx.saved_tensors

        gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1)

        gradgrad_out = upfirdn2d_cuda.upfirdn2d(
            gradgrad_input,
            kernel,
            ctx.up_x,
            ctx.up_y,
            ctx.down_x,
            ctx.down_y,
            ctx.pad_x0,
            ctx.pad_x1,
            ctx.pad_y0,
            ctx.pad_y1,
        )
        # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3])
        gradgrad_out = gradgrad_out.view(
            ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1]
        )

        return gradgrad_out, None, None, None, None, None, None, None, None


class UpFirDn2d(Function):
    @staticmethod
    def forward(ctx, input, kernel, up, down, pad):
        up_x, up_y = up
        down_x, down_y = down
        pad_x0, pad_x1, pad_y0, pad_y1 = pad

        kernel_h, kernel_w = kernel.shape
        batch, channel, in_h, in_w = input.shape
        ctx.in_size = input.shape

        input = input.reshape(-1, in_h, in_w, 1)

        ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))

        out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
        out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
        ctx.out_size = (out_h, out_w)

        ctx.up = (up_x, up_y)
        ctx.down = (down_x, down_y)
        ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)

        g_pad_x0 = kernel_w - pad_x0 - 1
        g_pad_y0 = kernel_h - pad_y0 - 1
        g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
        g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1

        ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)

        out = upfirdn2d_cuda.upfirdn2d(
            input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
        )
        # out = out.view(major, out_h, out_w, minor)
        out = out.view(-1, channel, out_h, out_w)

        return out

    @staticmethod
    def backward(ctx, grad_output):
        kernel, grad_kernel = ctx.saved_tensors

        grad_input = UpFirDn2dBackward.apply(
            grad_output,
            kernel,
            grad_kernel,
            ctx.up,
            ctx.down,
            ctx.pad,
            ctx.g_pad,
            ctx.in_size,
            ctx.out_size,
        )

        return grad_input, None, None, None, None


def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
    if input.device.type == "cpu":
        out = upfirdn2d_native(
            input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1]
        )
    else:
        out = UpFirDn2d.apply(
            input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1])
        )

    return out


def upfirdn2d_native(
    input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
):
    _, channel, in_h, in_w = input.shape
    input = input.reshape(-1, in_h, in_w, 1)

    _, in_h, in_w, minor = input.shape
    kernel_h, kernel_w = kernel.shape

    out = input.view(-1, in_h, 1, in_w, 1, minor)
    out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
    out = out.view(-1, in_h * up_y, in_w * up_x, minor)

    out = F.pad(
        out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
    )
    out = out[
        :,
        max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
        max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
        :,
    ]

    out = out.permute(0, 3, 1, 2)
    out = out.reshape(
        [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
    )
    w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
    out = F.conv2d(out, w)
    out = out.reshape(
        -1,
        minor,
        in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
        in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
    )
    out = out.permute(0, 2, 3, 1)
    out = out[:, ::down_y, ::down_x, :]

    out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
    out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1

    return out.view(-1, channel, out_h, out_w)
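The output size follows the formula used twice above: out = (in * up + pad0 + pad1 - kernel) // down + 1. A worked example of a StyleGAN2-style 2x upsample-and-blur using only upfirdn2d_native from above (a sketch, not part of the commit; the [1, 3, 3, 1] binomial kernel and the pad arithmetic follow StyleGAN2's usage):

import torch

k = torch.tensor([1., 3., 3., 1.])
k = k[None, :] * k[:, None]   # separable 4x4 binomial blur kernel
k = k / k.sum()

x = torch.randn(1, 3, 16, 16)
# up=2, down=1, pads (2, 1): out = (16 * 2 + 2 + 1 - 4) // 1 + 1 = 32
y = upfirdn2d_native(x, k * 4, 2, 2, 1, 1, 2, 1, 2, 1)   # kernel scaled by up**2 to preserve magnitude
print(y.shape)   # torch.Size([1, 3, 32, 32])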

View File: the StyleGAN2 architecture module (under models/archs/stylegan/)

@@ -8,7 +8,90 @@ from torch import nn
from torch.nn import functional as F
from torch.autograd import Function
from models.archs.stylegan.op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d

# Ops -> The rosinality repo uses custom CUDA kernels for the fused LeakyReLUs and upsamplers. This version extracts
# the "cpu" alternative code and uses that instead, for compatibility reasons.


class FusedLeakyReLU(nn.Module):
    def __init__(self, channel, bias=True, negative_slope=0.2, scale=2 ** 0.5):
        super().__init__()

        if bias:
            self.bias = nn.Parameter(torch.zeros(channel))
        else:
            self.bias = None

        self.negative_slope = negative_slope
        self.scale = scale

    def forward(self, input):
        return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)


def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2 ** 0.5):
    # Note: the upstream CPU fallback hard-coded negative_slope=0.2 in the calls below;
    # use the passed-in slope instead so the argument actually takes effect.
    if bias is not None:
        rest_dim = [1] * (input.ndim - bias.ndim - 1)
        return (
            F.leaky_relu(
                input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=negative_slope
            )
            * scale
        )
    else:
        return F.leaky_relu(input, negative_slope=negative_slope) * scale


def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
    out = upfirdn2d_native(
        input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1]
    )

    return out


def upfirdn2d_native(
    input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
):
    _, channel, in_h, in_w = input.shape
    input = input.reshape(-1, in_h, in_w, 1)

    _, in_h, in_w, minor = input.shape
    kernel_h, kernel_w = kernel.shape

    out = input.view(-1, in_h, 1, in_w, 1, minor)
    out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
    out = out.view(-1, in_h * up_y, in_w * up_x, minor)

    out = F.pad(
        out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
    )
    out = out[
        :,
        max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
        max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
        :,
    ]

    out = out.permute(0, 3, 1, 2)
    out = out.reshape(
        [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
    )
    w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
    out = F.conv2d(out, w)
    out = out.reshape(
        -1,
        minor,
        in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
        in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
    )
    out = out.permute(0, 2, 3, 1)
    out = out[:, ::down_y, ::down_x, :]

    out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
    out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1

    return out.view(-1, channel, out_h, out_w)
# /end Ops


class PixelNorm(nn.Module):

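A quick smoke test of the extracted CPU-only ops (a sketch, not part of the commit; assumes the definitions above are in scope):

import torch

act = FusedLeakyReLU(channel=8)
x = torch.randn(4, 8, 32, 32)
assert act(x).shape == x.shape   # bias broadcast over (N, C, H, W), then leaky ReLU * sqrt(2)

k = torch.ones(3, 3) / 9.0       # simple 3x3 box blur
z = upfirdn2d(x, k, up=1, down=1, pad=(1, 1))
assert z.shape == x.shape        # (32 + 1 + 1 - 3) // 1 + 1 == 32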
View File: the Tensorflow-to-PyTorch weight converter script

@@ -1,6 +1,5 @@
# Converts from Tensorflow Stylegan2 weights to weights used by this model.
# Original source: https://raw.githubusercontent.com/rosinality/stylegan2-pytorch/master/convert_weight.py
# Adapted to lucidrains' Stylegan implementation.
#
# Also doesn't require you to install Tensorflow 1.15 or clone the nVidia repo.
@@ -277,7 +276,7 @@ if __name__ == "__main__":
g = g.to(device)
z = np.random.RandomState(0).randn(n_sample, 512).astype("float32")
z = np.random.RandomState(5).randn(n_sample, 512).astype("float32")
with torch.no_grad():
img_pt, _ = g(