diff --git a/codes/models/archs/stylegan/op/__init__.py b/codes/models/archs/stylegan/op/__init__.py
deleted file mode 100644
index d0918d92..00000000
--- a/codes/models/archs/stylegan/op/__init__.py
+++ /dev/null
@@ -1,2 +0,0 @@
-from .fused_act import FusedLeakyReLU, fused_leaky_relu
-from .upfirdn2d import upfirdn2d
diff --git a/codes/models/archs/stylegan/op/fused_act.py b/codes/models/archs/stylegan/op/fused_act.py
deleted file mode 100644
index 5d9bd348..00000000
--- a/codes/models/archs/stylegan/op/fused_act.py
+++ /dev/null
@@ -1,110 +0,0 @@
-import os
-
-import torch
-from torch import nn
-from torch.nn import functional as F
-from torch.autograd import Function
-from torch.utils.cpp_extension import load
-import fused_bias_act_cuda
-
-
-class FusedLeakyReLUFunctionBackward(Function):
-    @staticmethod
-    def forward(ctx, grad_output, out, bias, negative_slope, scale):
-        ctx.save_for_backward(out)
-        ctx.negative_slope = negative_slope
-        ctx.scale = scale
-
-        empty = grad_output.new_empty(0)
-
-        grad_input = fused_bias_act_cuda.fused_bias_act(
-            grad_output, empty, out, 3, 1, negative_slope, scale
-        )
-
-        dim = [0]
-
-        if grad_input.ndim > 2:
-            dim += list(range(2, grad_input.ndim))
-
-        if bias:
-            grad_bias = grad_input.sum(dim).detach()
-
-        else:
-            grad_bias = empty
-
-        return grad_input, grad_bias
-
-    @staticmethod
-    def backward(ctx, gradgrad_input, gradgrad_bias):
-        out, = ctx.saved_tensors
-        gradgrad_out = fused_bias_act_cuda.fused_bias_act(
-            gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale
-        )
-
-        return gradgrad_out, None, None, None, None
-
-
-class FusedLeakyReLUFunction(Function):
-    @staticmethod
-    def forward(ctx, input, bias, negative_slope, scale):
-        empty = input.new_empty(0)
-
-        ctx.bias = bias is not None
-
-        if bias is None:
-            bias = empty
-
-        out = fused_bias_act_cuda.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale)
-        ctx.save_for_backward(out)
-        ctx.negative_slope = negative_slope
-        ctx.scale = scale
-
-        return out
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        out, = ctx.saved_tensors
-
-        grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(
-            grad_output, out, ctx.bias, ctx.negative_slope, ctx.scale
-        )
-
-        if not ctx.bias:
-            grad_bias = None
-
-        return grad_input, grad_bias, None, None
-
-
-class FusedLeakyReLU(nn.Module):
-    def __init__(self, channel, bias=True, negative_slope=0.2, scale=2 ** 0.5):
-        super().__init__()
-
-        if bias:
-            self.bias = nn.Parameter(torch.zeros(channel))
-
-        else:
-            self.bias = None
-
-        self.negative_slope = negative_slope
-        self.scale = scale
-
-    def forward(self, input):
-        return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
-
-
-def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2 ** 0.5):
-    if input.device.type == "cpu":
-        if bias is not None:
-            rest_dim = [1] * (input.ndim - bias.ndim - 1)
-            return (
-                F.leaky_relu(
-                    input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=0.2
-                )
-                * scale
-            )
-
-        else:
-            return F.leaky_relu(input, negative_slope=0.2) * scale
-
-    else:
-        return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)
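
The CPU branch of the deleted fused_leaky_relu above is the behavior this change preserves: a broadcast per-channel bias add, a leaky ReLU, then a sqrt(2) rescale that keeps activation variance roughly constant. A minimal standalone sketch of that computation (tensor shapes are illustrative, not from this repo):

import torch
from torch.nn import functional as F

x = torch.randn(2, 8, 16, 16)   # (batch, channel, height, width)
bias = torch.zeros(8)           # one bias per channel

# Broadcast the bias over the spatial dims, exactly as the CPU path does.
rest_dim = [1] * (x.ndim - bias.ndim - 1)
out = F.leaky_relu(x + bias.view(1, bias.shape[0], *rest_dim), negative_slope=0.2) * (2 ** 0.5)
assert out.shape == x.shape
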
diff --git a/codes/models/archs/stylegan/op/fused_bias_act.cpp b/codes/models/archs/stylegan/op/fused_bias_act.cpp
deleted file mode 100644
index 02be898f..00000000
--- a/codes/models/archs/stylegan/op/fused_bias_act.cpp
+++ /dev/null
@@ -1,21 +0,0 @@
-#include <torch/extension.h>
-
-
-torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
-    int act, int grad, float alpha, float scale);
-
-#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
-#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
-#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
-
-torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
-    int act, int grad, float alpha, float scale) {
-    CHECK_CUDA(input);
-    CHECK_CUDA(bias);
-
-    return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale);
-}
-
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-    m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)");
-}
\ No newline at end of file
diff --git a/codes/models/archs/stylegan/op/setup.py b/codes/models/archs/stylegan/op/setup.py
deleted file mode 100644
index 2d844fe5..00000000
--- a/codes/models/archs/stylegan/op/setup.py
+++ /dev/null
@@ -1,32 +0,0 @@
-#!/usr/bin/env python3
-import os
-import torch
-
-from setuptools import setup, find_packages
-from torch.utils.cpp_extension import BuildExtension, CUDAExtension
-
-cxx_args = ['-std=c++11']
-
-nvcc_args = [
-    '-gencode', 'arch=compute_50,code=sm_50',
-    '-gencode', 'arch=compute_52,code=sm_52',
-    '-gencode', 'arch=compute_60,code=sm_60',
-    '-gencode', 'arch=compute_61,code=sm_61',
-    '-gencode', 'arch=compute_70,code=sm_70',
-]
-
-setup(
-    name='stylegan2_ops_cuda',
-    ext_modules=[
-        CUDAExtension('fused_bias_act_cuda', [
-            'fused_bias_act.cpp',
-            'fused_bias_act_kernel.cu'
-        ], extra_compile_args={'cxx': cxx_args, 'nvcc': nvcc_args}),
-        CUDAExtension('upfirdn2d_cuda', [
-            'upfirdn2d.cpp',
-            'upfirdn2d_kernel.cu'
-        ], extra_compile_args={'cxx': cxx_args, 'nvcc': nvcc_args})
-    ],
-    cmdclass={
-        'build_ext': BuildExtension
-    })
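
The deleted setup.py pre-built the two extensions ahead of time with setuptools; the upstream rosinality repo instead JIT-compiles the same sources on first import. A sketch of that alternative, assuming the .cpp/.cu sources sit next to the importing module (the name and paths here are illustrative):

import os
from torch.utils.cpp_extension import load

module_path = os.path.dirname(__file__)
# Compiles on the first call, then reuses the cached extension.
fused_bias_act_cuda = load(
    "fused_bias_act_cuda",
    sources=[
        os.path.join(module_path, "fused_bias_act.cpp"),
        os.path.join(module_path, "fused_bias_act_kernel.cu"),
    ],
)
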
diff --git a/codes/models/archs/stylegan/op/upfirdn2d.cpp b/codes/models/archs/stylegan/op/upfirdn2d.cpp
deleted file mode 100644
index d2e633dc..00000000
--- a/codes/models/archs/stylegan/op/upfirdn2d.cpp
+++ /dev/null
@@ -1,23 +0,0 @@
-#include <torch/extension.h>
-
-
-torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,
-    int up_x, int up_y, int down_x, int down_y,
-    int pad_x0, int pad_x1, int pad_y0, int pad_y1);
-
-#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
-#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
-#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
-
-torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel,
-    int up_x, int up_y, int down_x, int down_y,
-    int pad_x0, int pad_x1, int pad_y0, int pad_y1) {
-    CHECK_CUDA(input);
-    CHECK_CUDA(kernel);
-
-    return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1);
-}
-
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-    m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
-}
\ No newline at end of file
diff --git a/codes/models/archs/stylegan/op/upfirdn2d.py b/codes/models/archs/stylegan/op/upfirdn2d.py
deleted file mode 100644
index cb00d0bd..00000000
--- a/codes/models/archs/stylegan/op/upfirdn2d.py
+++ /dev/null
@@ -1,191 +0,0 @@
-import os
-
-import torch
-from torch.nn import functional as F
-from torch.autograd import Function
-from torch.utils.cpp_extension import load
-import upfirdn2d_cuda
-
-
-class UpFirDn2dBackward(Function):
-    @staticmethod
-    def forward(
-        ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size
-    ):
-
-        up_x, up_y = up
-        down_x, down_y = down
-        g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad
-
-        grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)
-
-        grad_input = upfirdn2d_cuda.upfirdn2d(
-            grad_output,
-            grad_kernel,
-            down_x,
-            down_y,
-            up_x,
-            up_y,
-            g_pad_x0,
-            g_pad_x1,
-            g_pad_y0,
-            g_pad_y1,
-        )
-        grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3])
-
-        ctx.save_for_backward(kernel)
-
-        pad_x0, pad_x1, pad_y0, pad_y1 = pad
-
-        ctx.up_x = up_x
-        ctx.up_y = up_y
-        ctx.down_x = down_x
-        ctx.down_y = down_y
-        ctx.pad_x0 = pad_x0
-        ctx.pad_x1 = pad_x1
-        ctx.pad_y0 = pad_y0
-        ctx.pad_y1 = pad_y1
-        ctx.in_size = in_size
-        ctx.out_size = out_size
-
-        return grad_input
-
-    @staticmethod
-    def backward(ctx, gradgrad_input):
-        kernel, = ctx.saved_tensors
-
-        gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1)
-
-        gradgrad_out = upfirdn2d_cuda.upfirdn2d(
-            gradgrad_input,
-            kernel,
-            ctx.up_x,
-            ctx.up_y,
-            ctx.down_x,
-            ctx.down_y,
-            ctx.pad_x0,
-            ctx.pad_x1,
-            ctx.pad_y0,
-            ctx.pad_y1,
-        )
-        # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3])
-        gradgrad_out = gradgrad_out.view(
-            ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1]
-        )
-
-        return gradgrad_out, None, None, None, None, None, None, None, None
-
-
-class UpFirDn2d(Function):
-    @staticmethod
-    def forward(ctx, input, kernel, up, down, pad):
-        up_x, up_y = up
-        down_x, down_y = down
-        pad_x0, pad_x1, pad_y0, pad_y1 = pad
-
-        kernel_h, kernel_w = kernel.shape
-        batch, channel, in_h, in_w = input.shape
-        ctx.in_size = input.shape
-
-        input = input.reshape(-1, in_h, in_w, 1)
-
-        ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))
-
-        out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
-        out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
-        ctx.out_size = (out_h, out_w)
-
-        ctx.up = (up_x, up_y)
-        ctx.down = (down_x, down_y)
-        ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)
-
-        g_pad_x0 = kernel_w - pad_x0 - 1
-        g_pad_y0 = kernel_h - pad_y0 - 1
-        g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
-        g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1
-
-        ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)
-
-        out = upfirdn2d_cuda.upfirdn2d(
-            input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
-        )
-        # out = out.view(major, out_h, out_w, minor)
-        out = out.view(-1, channel, out_h, out_w)
-
-        return out
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        kernel, grad_kernel = ctx.saved_tensors
-
-        grad_input = UpFirDn2dBackward.apply(
-            grad_output,
-            kernel,
-            grad_kernel,
-            ctx.up,
-            ctx.down,
-            ctx.pad,
-            ctx.g_pad,
-            ctx.in_size,
-            ctx.out_size,
-        )
-
-        return grad_input, None, None, None, None
-
-
-def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
-    if input.device.type == "cpu":
-        out = upfirdn2d_native(
-            input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1]
-        )
-
-    else:
-        out = UpFirDn2d.apply(
-            input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1])
-        )
-
-    return out
-
-
-def upfirdn2d_native(
-    input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
-):
-    _, channel, in_h, in_w = input.shape
-    input = input.reshape(-1, in_h, in_w, 1)
-
-    _, in_h, in_w, minor = input.shape
-    kernel_h, kernel_w = kernel.shape
-
-    out = input.view(-1, in_h, 1, in_w, 1, minor)
-    out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
-    out = out.view(-1, in_h * up_y, in_w * up_x, minor)
-
-    out = F.pad(
-        out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
-    )
-    out = out[
-        :,
-        max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
-        max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
-        :,
-    ]
-
-    out = out.permute(0, 3, 1, 2)
-    out = out.reshape(
-        [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
-    )
-    w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
-    out = F.conv2d(out, w)
-    out = out.reshape(
-        -1,
-        minor,
-        in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
-        in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
-    )
-    out = out.permute(0, 2, 3, 1)
-    out = out[:, ::down_y, ::down_x, :]
-
-    out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
-    out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
-
-    return out.view(-1, channel, out_h, out_w)
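
Both the CUDA kernel and upfirdn2d_native land on the same output size per axis: out = (in * up + pad0 + pad1 - kernel) // down + 1, as in the formulas above. A worked check using StyleGAN2's usual 2x upsample configuration (a 4-tap blur kernel padded by (2, 1); the concrete values are illustrative):

in_h, up, down, kernel_h = 16, 2, 1, 4
pad_y0, pad_y1 = 2, 1
out_h = (in_h * up + pad_y0 + pad_y1 - kernel_h) // down + 1
assert out_h == 32  # (32 + 3 - 4) // 1 + 1: exactly twice the input height
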
diff --git a/codes/models/archs/stylegan/stylegan2_rosinality.py b/codes/models/archs/stylegan/stylegan2_rosinality.py
index a4609be2..8a8b8ce6 100644
--- a/codes/models/archs/stylegan/stylegan2_rosinality.py
+++ b/codes/models/archs/stylegan/stylegan2_rosinality.py
@@ -8,7 +8,90 @@
 from torch import nn
 from torch.nn import functional as F
 from torch.autograd import Function
-from models.archs.stylegan.op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d
+
+# Ops -> The rosinality repo uses native CUDA kernels for fused LeakyReLUs and upsamplers. This version extracts the
+# "cpu" alternative code and uses that instead for compatibility reasons.
+class FusedLeakyReLU(nn.Module):
+    def __init__(self, channel, bias=True, negative_slope=0.2, scale=2 ** 0.5):
+        super().__init__()
+
+        if bias:
+            self.bias = nn.Parameter(torch.zeros(channel))
+
+        else:
+            self.bias = None
+
+        self.negative_slope = negative_slope
+        self.scale = scale
+
+    def forward(self, input):
+        return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
+
+
+def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2 ** 0.5):
+    if bias is not None:
+        rest_dim = [1] * (input.ndim - bias.ndim - 1)
+        return (
+            F.leaky_relu(
+                input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=negative_slope
+            )
+            * scale
+        )
+
+    else:
+        return F.leaky_relu(input, negative_slope=negative_slope) * scale
+
+
+def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
+    out = upfirdn2d_native(
+        input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1]
+    )
+    return out
+
+
+def upfirdn2d_native(
+    input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
+):
+    _, channel, in_h, in_w = input.shape
+    input = input.reshape(-1, in_h, in_w, 1)
+
+    _, in_h, in_w, minor = input.shape
+    kernel_h, kernel_w = kernel.shape
+
+    out = input.view(-1, in_h, 1, in_w, 1, minor)
+    out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
+    out = out.view(-1, in_h * up_y, in_w * up_x, minor)
+
+    out = F.pad(
+        out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
+    )
+    out = out[
+        :,
+        max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
+        max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
+        :,
+    ]
+
+    out = out.permute(0, 3, 1, 2)
+    out = out.reshape(
+        [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
+    )
+    w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
+    out = F.conv2d(out, w)
+    out = out.reshape(
+        -1,
+        minor,
+        in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
+        in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
+    )
+    out = out.permute(0, 2, 3, 1)
+    out = out[:, ::down_y, ::down_x, :]
+
+    out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
+    out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
+
+    return out.view(-1, channel, out_h, out_w)
+# /end Ops
 
 
 class PixelNorm(nn.Module):
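
With the ops inlined, the model no longer needs the compiled extensions at all. A quick smoke test of the extracted CPU ops (import path per this repo's layout; the input shapes and box-filter kernel are illustrative):

import torch
from models.archs.stylegan.stylegan2_rosinality import FusedLeakyReLU, upfirdn2d

act = FusedLeakyReLU(channel=8)
x = torch.randn(2, 8, 16, 16)
assert act(x).shape == x.shape  # fused bias + leaky ReLU preserves shape

kernel = torch.ones(3, 3) / 9.0  # simple normalized box filter
y = upfirdn2d(x, kernel, up=1, down=1, pad=(1, 1))
assert y.shape == x.shape  # (16 + 1 + 1 - 3) // 1 + 1 == 16
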
diff --git a/codes/scripts/stylegan2/convert_weights_lucidrains.py b/codes/scripts/stylegan2/convert_weights_lucidrains.py
new file mode 100644
index 00000000..e69de29b
diff --git a/codes/scripts/stylegan2/convert_weights.py b/codes/scripts/stylegan2/convert_weights_rosinality.py
similarity index 98%
rename from codes/scripts/stylegan2/convert_weights.py
rename to codes/scripts/stylegan2/convert_weights_rosinality.py
index b165d25e..a2154669 100644
--- a/codes/scripts/stylegan2/convert_weights.py
+++ b/codes/scripts/stylegan2/convert_weights_rosinality.py
@@ -1,6 +1,5 @@
 # Converts from Tensorflow Stylegan2 weights to weights used by this model.
 # Original source: https://raw.githubusercontent.com/rosinality/stylegan2-pytorch/master/convert_weight.py
-# Adapted to lucidrains' Stylegan implementation.
 #
 # Also doesn't require you to install Tensorflow 1.15 or clone the nVidia repo.
 
@@ -277,7 +276,7 @@ if __name__ == "__main__":
 
     g = g.to(device)
 
-    z = np.random.RandomState(0).randn(n_sample, 512).astype("float32")
+    z = np.random.RandomState(5).randn(n_sample, 512).astype("float32")
 
     with torch.no_grad():
         img_pt, _ = g(