From e838c6e75b1fbacbb700bd7476b15b6a30c2b9ee Mon Sep 17 00:00:00 2001 From: James Betker Date: Thu, 17 Dec 2020 14:18:46 -0700 Subject: [PATCH] Rosinality stylegan2 port --- codes/data/stylegan2_dataset.py | 2 +- codes/models/archs/stylegan/__init__.py | 5 +- codes/models/archs/stylegan/op/__init__.py | 2 + codes/models/archs/stylegan/op/fused_act.py | 110 +++ .../archs/stylegan/op/fused_bias_act.cpp | 21 + codes/models/archs/stylegan/op/setup.py | 32 + codes/models/archs/stylegan/op/upfirdn2d.cpp | 23 + codes/models/archs/stylegan/op/upfirdn2d.py | 191 +++++ .../{stylegan2.py => stylegan2_lucidrains.py} | 8 +- .../archs/stylegan/stylegan2_rosinality.py | 660 ++++++++++++++++++ .../archs/stylegan/stylegan2_unet_disc.py | 243 ------- codes/models/networks.py | 6 +- codes/models/steps/losses.py | 2 +- codes/models/steps/tecogan_losses.py | 2 +- codes/scripts/stylegan2/convert_weights.py | 292 ++++++++ .../scripts/stylegan2/dnnlib/tflib/network.py | 17 + 16 files changed, 1357 insertions(+), 259 deletions(-) create mode 100644 codes/models/archs/stylegan/op/__init__.py create mode 100644 codes/models/archs/stylegan/op/fused_act.py create mode 100644 codes/models/archs/stylegan/op/fused_bias_act.cpp create mode 100644 codes/models/archs/stylegan/op/setup.py create mode 100644 codes/models/archs/stylegan/op/upfirdn2d.cpp create mode 100644 codes/models/archs/stylegan/op/upfirdn2d.py rename codes/models/archs/stylegan/{stylegan2.py => stylegan2_lucidrains.py} (98%) create mode 100644 codes/models/archs/stylegan/stylegan2_rosinality.py delete mode 100644 codes/models/archs/stylegan/stylegan2_unet_disc.py create mode 100644 codes/scripts/stylegan2/convert_weights.py create mode 100644 codes/scripts/stylegan2/dnnlib/tflib/network.py diff --git a/codes/data/stylegan2_dataset.py b/codes/data/stylegan2_dataset.py index b00f591d..c52045d8 100644 --- a/codes/data/stylegan2_dataset.py +++ b/codes/data/stylegan2_dataset.py @@ -9,7 +9,7 @@ from torchvision import transforms import torch.nn as nn from pathlib import Path -import models.archs.stylegan.stylegan2 as sg2 +import models.archs.stylegan.stylegan2_lucidrains as sg2 def convert_transparent_to_rgb(image): diff --git a/codes/models/archs/stylegan/__init__.py b/codes/models/archs/stylegan/__init__.py index c38166f4..ab92d1d5 100644 --- a/codes/models/archs/stylegan/__init__.py +++ b/codes/models/archs/stylegan/__init__.py @@ -1,5 +1,4 @@ -import models.archs.stylegan.stylegan2 as stylegan2 -import models.archs.stylegan.stylegan2_unet_disc as stylegan2_unet +import models.archs.stylegan.stylegan2_lucidrains as stylegan2 def create_stylegan2_loss(opt_loss, env): @@ -8,7 +7,5 @@ def create_stylegan2_loss(opt_loss, env): return stylegan2.StyleGan2DivergenceLoss(opt_loss, env) elif type == 'stylegan2_pathlen': return stylegan2.StyleGan2PathLengthLoss(opt_loss, env) - elif type == 'stylegan2_unet_divergence': - return stylegan2_unet.StyleGan2UnetDivergenceLoss(opt_loss, env) else: raise NotImplementedError \ No newline at end of file diff --git a/codes/models/archs/stylegan/op/__init__.py b/codes/models/archs/stylegan/op/__init__.py new file mode 100644 index 00000000..d0918d92 --- /dev/null +++ b/codes/models/archs/stylegan/op/__init__.py @@ -0,0 +1,2 @@ +from .fused_act import FusedLeakyReLU, fused_leaky_relu +from .upfirdn2d import upfirdn2d diff --git a/codes/models/archs/stylegan/op/fused_act.py b/codes/models/archs/stylegan/op/fused_act.py new file mode 100644 index 00000000..5d9bd348 --- /dev/null +++ b/codes/models/archs/stylegan/op/fused_act.py @@ 
-0,0 +1,110 @@
+import os
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+from torch.autograd import Function
+from torch.utils.cpp_extension import load
+import fused_bias_act_cuda
+
+
+class FusedLeakyReLUFunctionBackward(Function):
+    @staticmethod
+    def forward(ctx, grad_output, out, bias, negative_slope, scale):
+        ctx.save_for_backward(out)
+        ctx.negative_slope = negative_slope
+        ctx.scale = scale
+
+        empty = grad_output.new_empty(0)
+
+        grad_input = fused_bias_act_cuda.fused_bias_act(
+            grad_output, empty, out, 3, 1, negative_slope, scale
+        )
+
+        dim = [0]
+
+        if grad_input.ndim > 2:
+            dim += list(range(2, grad_input.ndim))
+
+        if bias:
+            grad_bias = grad_input.sum(dim).detach()
+
+        else:
+            grad_bias = empty
+
+        return grad_input, grad_bias
+
+    @staticmethod
+    def backward(ctx, gradgrad_input, gradgrad_bias):
+        out, = ctx.saved_tensors
+        gradgrad_out = fused_bias_act_cuda.fused_bias_act(
+            gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale
+        )
+
+        return gradgrad_out, None, None, None, None
+
+
+class FusedLeakyReLUFunction(Function):
+    @staticmethod
+    def forward(ctx, input, bias, negative_slope, scale):
+        empty = input.new_empty(0)
+
+        ctx.bias = bias is not None
+
+        if bias is None:
+            bias = empty
+
+        out = fused_bias_act_cuda.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale)
+        ctx.save_for_backward(out)
+        ctx.negative_slope = negative_slope
+        ctx.scale = scale
+
+        return out
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        out, = ctx.saved_tensors
+
+        grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(
+            grad_output, out, ctx.bias, ctx.negative_slope, ctx.scale
+        )
+
+        if not ctx.bias:
+            grad_bias = None
+
+        return grad_input, grad_bias, None, None
+
+
+class FusedLeakyReLU(nn.Module):
+    def __init__(self, channel, bias=True, negative_slope=0.2, scale=2 ** 0.5):
+        super().__init__()
+
+        if bias:
+            self.bias = nn.Parameter(torch.zeros(channel))
+
+        else:
+            self.bias = None
+
+        self.negative_slope = negative_slope
+        self.scale = scale
+
+    def forward(self, input):
+        return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
+
+
+def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2 ** 0.5):
+    if input.device.type == "cpu":
+        if bias is not None:
+            rest_dim = [1] * (input.ndim - bias.ndim - 1)
+            return (
+                F.leaky_relu(
+                    input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=0.2
+                )
+                * scale
+            )
+
+        else:
+            return F.leaky_relu(input, negative_slope=0.2) * scale
+
+    else:
+        return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)
diff --git a/codes/models/archs/stylegan/op/fused_bias_act.cpp b/codes/models/archs/stylegan/op/fused_bias_act.cpp
new file mode 100644
index 00000000..02be898f
--- /dev/null
+++ b/codes/models/archs/stylegan/op/fused_bias_act.cpp
@@ -0,0 +1,21 @@
+#include <torch/extension.h>
+
+
+torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
+    int act, int grad, float alpha, float scale);
+
+#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
+#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
+#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
+
+torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
+    int act, int grad, float alpha, float scale) {
+    CHECK_CUDA(input);
+    CHECK_CUDA(bias);
+
+    return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale);
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+    m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)");
+}
\ No newline at end of file
diff --git a/codes/models/archs/stylegan/op/setup.py b/codes/models/archs/stylegan/op/setup.py
new file mode 100644
index 00000000..2d844fe5
--- /dev/null
+++ b/codes/models/archs/stylegan/op/setup.py
@@ -0,0 +1,32 @@
+#!/usr/bin/env python3
+import os
+import torch
+
+from setuptools import setup, find_packages
+from torch.utils.cpp_extension import BuildExtension, CUDAExtension
+
+cxx_args = ['-std=c++11']
+
+nvcc_args = [
+    '-gencode', 'arch=compute_50,code=sm_50',
+    '-gencode', 'arch=compute_52,code=sm_52',
+    '-gencode', 'arch=compute_60,code=sm_60',
+    '-gencode', 'arch=compute_61,code=sm_61',
+    '-gencode', 'arch=compute_70,code=sm_70',
+]
+
+setup(
+    name='stylegan2_ops_cuda',
+    ext_modules=[
+        CUDAExtension('fused_bias_act_cuda', [
+            'fused_bias_act.cpp',
+            'fused_bias_act_kernel.cu'
+        ], extra_compile_args={'cxx': cxx_args, 'nvcc': nvcc_args}),
+        CUDAExtension('upfirdn2d_cuda', [
+            'upfirdn2d.cpp',
+            'upfirdn2d_kernel.cu'
+        ], extra_compile_args={'cxx': cxx_args, 'nvcc': nvcc_args})
+    ],
+    cmdclass={
+        'build_ext': BuildExtension
+    })
diff --git a/codes/models/archs/stylegan/op/upfirdn2d.cpp b/codes/models/archs/stylegan/op/upfirdn2d.cpp
new file mode 100644
index 00000000..d2e633dc
--- /dev/null
+++ b/codes/models/archs/stylegan/op/upfirdn2d.cpp
@@ -0,0 +1,23 @@
+#include <torch/extension.h>
+
+
+torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,
+    int up_x, int up_y, int down_x, int down_y,
+    int pad_x0, int pad_x1, int pad_y0, int pad_y1);
+
+#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
+#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
+#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
+
+torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel,
+    int up_x, int up_y, int down_x, int down_y,
+    int pad_x0, int pad_x1, int pad_y0, int pad_y1) {
+    CHECK_CUDA(input);
+    CHECK_CUDA(kernel);
+
+    return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1);
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+    m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
+}
\ No newline at end of file
diff --git a/codes/models/archs/stylegan/op/upfirdn2d.py b/codes/models/archs/stylegan/op/upfirdn2d.py
new file mode 100644
index 00000000..cb00d0bd
--- /dev/null
+++ b/codes/models/archs/stylegan/op/upfirdn2d.py
@@ -0,0 +1,191 @@
+import os
+
+import torch
+from torch.nn import functional as F
+from torch.autograd import Function
+from torch.utils.cpp_extension import load
+import upfirdn2d_cuda
+
+
+class UpFirDn2dBackward(Function):
+    @staticmethod
+    def forward(
+        ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size
+    ):
+
+        up_x, up_y = up
+        down_x, down_y = down
+        g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad
+
+        grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)
+
+        grad_input = upfirdn2d_cuda.upfirdn2d(
+            grad_output,
+            grad_kernel,
+            down_x,
+            down_y,
+            up_x,
+            up_y,
+            g_pad_x0,
+            g_pad_x1,
+            g_pad_y0,
+            g_pad_y1,
+        )
+        grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3])
+
+        ctx.save_for_backward(kernel)
+
+        pad_x0, pad_x1, pad_y0, pad_y1 = pad
+
+        ctx.up_x = up_x
+        ctx.up_y = up_y
+        ctx.down_x = down_x
+        ctx.down_y = down_y
+        ctx.pad_x0 = pad_x0
+        ctx.pad_x1 = pad_x1
+        ctx.pad_y0 = pad_y0
+        ctx.pad_y1 = pad_y1
+        ctx.in_size = in_size +
ctx.out_size = out_size + + return grad_input + + @staticmethod + def backward(ctx, gradgrad_input): + kernel, = ctx.saved_tensors + + gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1) + + gradgrad_out = upfirdn2d_cuda.upfirdn2d( + gradgrad_input, + kernel, + ctx.up_x, + ctx.up_y, + ctx.down_x, + ctx.down_y, + ctx.pad_x0, + ctx.pad_x1, + ctx.pad_y0, + ctx.pad_y1, + ) + # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3]) + gradgrad_out = gradgrad_out.view( + ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1] + ) + + return gradgrad_out, None, None, None, None, None, None, None, None + + +class UpFirDn2d(Function): + @staticmethod + def forward(ctx, input, kernel, up, down, pad): + up_x, up_y = up + down_x, down_y = down + pad_x0, pad_x1, pad_y0, pad_y1 = pad + + kernel_h, kernel_w = kernel.shape + batch, channel, in_h, in_w = input.shape + ctx.in_size = input.shape + + input = input.reshape(-1, in_h, in_w, 1) + + ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1])) + + out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 + out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 + ctx.out_size = (out_h, out_w) + + ctx.up = (up_x, up_y) + ctx.down = (down_x, down_y) + ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1) + + g_pad_x0 = kernel_w - pad_x0 - 1 + g_pad_y0 = kernel_h - pad_y0 - 1 + g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1 + g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1 + + ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1) + + out = upfirdn2d_cuda.upfirdn2d( + input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 + ) + # out = out.view(major, out_h, out_w, minor) + out = out.view(-1, channel, out_h, out_w) + + return out + + @staticmethod + def backward(ctx, grad_output): + kernel, grad_kernel = ctx.saved_tensors + + grad_input = UpFirDn2dBackward.apply( + grad_output, + kernel, + grad_kernel, + ctx.up, + ctx.down, + ctx.pad, + ctx.g_pad, + ctx.in_size, + ctx.out_size, + ) + + return grad_input, None, None, None, None + + +def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): + if input.device.type == "cpu": + out = upfirdn2d_native( + input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1] + ) + + else: + out = UpFirDn2d.apply( + input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1]) + ) + + return out + + +def upfirdn2d_native( + input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 +): + _, channel, in_h, in_w = input.shape + input = input.reshape(-1, in_h, in_w, 1) + + _, in_h, in_w, minor = input.shape + kernel_h, kernel_w = kernel.shape + + out = input.view(-1, in_h, 1, in_w, 1, minor) + out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) + out = out.view(-1, in_h * up_y, in_w * up_x, minor) + + out = F.pad( + out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)] + ) + out = out[ + :, + max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), + max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0), + :, + ] + + out = out.permute(0, 3, 1, 2) + out = out.reshape( + [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1] + ) + w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) + out = F.conv2d(out, w) + out = out.reshape( + -1, + minor, + in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, + in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, + ) + out = out.permute(0, 2, 3, 1) + out = out[:, ::down_y, ::down_x, :] + + out_h = (in_h * up_y + 
pad_y0 + pad_y1 - kernel_h) // down_y + 1 + out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 + + return out.view(-1, channel, out_h, out_w) diff --git a/codes/models/archs/stylegan/stylegan2.py b/codes/models/archs/stylegan/stylegan2_lucidrains.py similarity index 98% rename from codes/models/archs/stylegan/stylegan2.py rename to codes/models/archs/stylegan/stylegan2_lucidrains.py index 8cd9365c..b1afb619 100644 --- a/codes/models/archs/stylegan/stylegan2.py +++ b/codes/models/archs/stylegan/stylegan2_lucidrains.py @@ -858,7 +858,7 @@ class StyleGan2DivergenceLoss(L.ConfigurableLoss): # Apply gradient penalty. TODO: migrate this elsewhere. if self.env['step'] % self.gp_frequency == 0: - from models.archs.stylegan.stylegan2 import gradient_penalty + from models.archs.stylegan.stylegan2_lucidrains import gradient_penalty gp = gradient_penalty(real_input, real) self.metrics.append(("gradient_penalty", gp.clone().detach())) divergence_loss = divergence_loss + gp @@ -873,17 +873,17 @@ class StyleGan2PathLengthLoss(L.ConfigurableLoss): self.w_styles = opt['w_styles'] self.gen = opt['gen'] self.pl_mean = None - from models.archs.stylegan.stylegan2 import EMA + from models.archs.stylegan.stylegan2_lucidrains import EMA self.pl_length_ma = EMA(.99) def forward(self, net, state): w_styles = state[self.w_styles] gen = state[self.gen] - from models.archs.stylegan.stylegan2 import calc_pl_lengths + from models.archs.stylegan.stylegan2_lucidrains import calc_pl_lengths pl_lengths = calc_pl_lengths(w_styles, gen) avg_pl_length = np.mean(pl_lengths.detach().cpu().numpy()) - from models.archs.stylegan.stylegan2 import is_empty + from models.archs.stylegan.stylegan2_lucidrains import is_empty if not is_empty(self.pl_mean): pl_loss = ((pl_lengths - self.pl_mean) ** 2).mean() if not torch.isnan(pl_loss): diff --git a/codes/models/archs/stylegan/stylegan2_rosinality.py b/codes/models/archs/stylegan/stylegan2_rosinality.py new file mode 100644 index 00000000..a4609be2 --- /dev/null +++ b/codes/models/archs/stylegan/stylegan2_rosinality.py @@ -0,0 +1,660 @@ +import math +import random +import functools +import operator + +import torch +from torch import nn +from torch.nn import functional as F +from torch.autograd import Function + +from models.archs.stylegan.op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d + + +class PixelNorm(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input): + return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8) + + +def make_kernel(k): + k = torch.tensor(k, dtype=torch.float32) + + if k.ndim == 1: + k = k[None, :] * k[:, None] + + k /= k.sum() + + return k + + +class Upsample(nn.Module): + def __init__(self, kernel, factor=2): + super().__init__() + + self.factor = factor + kernel = make_kernel(kernel) * (factor ** 2) + self.register_buffer("kernel", kernel) + + p = kernel.shape[0] - factor + + pad0 = (p + 1) // 2 + factor - 1 + pad1 = p // 2 + + self.pad = (pad0, pad1) + + def forward(self, input): + out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad) + + return out + + +class Downsample(nn.Module): + def __init__(self, kernel, factor=2): + super().__init__() + + self.factor = factor + kernel = make_kernel(kernel) + self.register_buffer("kernel", kernel) + + p = kernel.shape[0] - factor + + pad0 = (p + 1) // 2 + pad1 = p // 2 + + self.pad = (pad0, pad1) + + def forward(self, input): + out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad) + + return out + + +class 
Blur(nn.Module): + def __init__(self, kernel, pad, upsample_factor=1): + super().__init__() + + kernel = make_kernel(kernel) + + if upsample_factor > 1: + kernel = kernel * (upsample_factor ** 2) + + self.register_buffer("kernel", kernel) + + self.pad = pad + + def forward(self, input): + out = upfirdn2d(input, self.kernel, pad=self.pad) + + return out + + +class EqualConv2d(nn.Module): + def __init__( + self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True + ): + super().__init__() + + self.weight = nn.Parameter( + torch.randn(out_channel, in_channel, kernel_size, kernel_size) + ) + self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2) + + self.stride = stride + self.padding = padding + + if bias: + self.bias = nn.Parameter(torch.zeros(out_channel)) + + else: + self.bias = None + + def forward(self, input): + out = F.conv2d( + input, + self.weight * self.scale, + bias=self.bias, + stride=self.stride, + padding=self.padding, + ) + + return out + + def __repr__(self): + return ( + f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]}," + f" {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})" + ) + + +class EqualLinear(nn.Module): + def __init__( + self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None + ): + super().__init__() + + self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul)) + + if bias: + self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init)) + + else: + self.bias = None + + self.activation = activation + + self.scale = (1 / math.sqrt(in_dim)) * lr_mul + self.lr_mul = lr_mul + + def forward(self, input): + if self.activation: + out = F.linear(input, self.weight * self.scale) + out = fused_leaky_relu(out, self.bias * self.lr_mul) + + else: + out = F.linear( + input, self.weight * self.scale, bias=self.bias * self.lr_mul + ) + + return out + + def __repr__(self): + return ( + f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})" + ) + + +class ModulatedConv2d(nn.Module): + def __init__( + self, + in_channel, + out_channel, + kernel_size, + style_dim, + demodulate=True, + upsample=False, + downsample=False, + blur_kernel=[1, 3, 3, 1], + ): + super().__init__() + + self.eps = 1e-8 + self.kernel_size = kernel_size + self.in_channel = in_channel + self.out_channel = out_channel + self.upsample = upsample + self.downsample = downsample + + if upsample: + factor = 2 + p = (len(blur_kernel) - factor) - (kernel_size - 1) + pad0 = (p + 1) // 2 + factor - 1 + pad1 = p // 2 + 1 + + self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor) + + if downsample: + factor = 2 + p = (len(blur_kernel) - factor) + (kernel_size - 1) + pad0 = (p + 1) // 2 + pad1 = p // 2 + + self.blur = Blur(blur_kernel, pad=(pad0, pad1)) + + fan_in = in_channel * kernel_size ** 2 + self.scale = 1 / math.sqrt(fan_in) + self.padding = kernel_size // 2 + + self.weight = nn.Parameter( + torch.randn(1, out_channel, in_channel, kernel_size, kernel_size) + ) + + self.modulation = EqualLinear(style_dim, in_channel, bias_init=1) + + self.demodulate = demodulate + + def __repr__(self): + return ( + f"{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, " + f"upsample={self.upsample}, downsample={self.downsample})" + ) + + def forward(self, input, style): + batch, in_channel, height, width = input.shape + + style = self.modulation(style).view(batch, 1, in_channel, 1, 1) + weight = self.scale * self.weight * style + + if self.demodulate: + demod = 
torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8) + weight = weight * demod.view(batch, self.out_channel, 1, 1, 1) + + weight = weight.view( + batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size + ) + + if self.upsample: + input = input.view(1, batch * in_channel, height, width) + weight = weight.view( + batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size + ) + weight = weight.transpose(1, 2).reshape( + batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size + ) + out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch) + _, _, height, width = out.shape + out = out.view(batch, self.out_channel, height, width) + out = self.blur(out) + + elif self.downsample: + input = self.blur(input) + _, _, height, width = input.shape + input = input.view(1, batch * in_channel, height, width) + out = F.conv2d(input, weight, padding=0, stride=2, groups=batch) + _, _, height, width = out.shape + out = out.view(batch, self.out_channel, height, width) + + else: + input = input.view(1, batch * in_channel, height, width) + out = F.conv2d(input, weight, padding=self.padding, groups=batch) + _, _, height, width = out.shape + out = out.view(batch, self.out_channel, height, width) + + return out + + +class NoiseInjection(nn.Module): + def __init__(self): + super().__init__() + + self.weight = nn.Parameter(torch.zeros(1)) + + def forward(self, image, noise=None): + if noise is None: + batch, _, height, width = image.shape + noise = image.new_empty(batch, 1, height, width).normal_() + + return image + self.weight * noise + + +class ConstantInput(nn.Module): + def __init__(self, channel, size=4): + super().__init__() + + self.input = nn.Parameter(torch.randn(1, channel, size, size)) + + def forward(self, input): + batch = input.shape[0] + out = self.input.repeat(batch, 1, 1, 1) + + return out + + +class StyledConv(nn.Module): + def __init__( + self, + in_channel, + out_channel, + kernel_size, + style_dim, + upsample=False, + blur_kernel=[1, 3, 3, 1], + demodulate=True, + ): + super().__init__() + + self.conv = ModulatedConv2d( + in_channel, + out_channel, + kernel_size, + style_dim, + upsample=upsample, + blur_kernel=blur_kernel, + demodulate=demodulate, + ) + + self.noise = NoiseInjection() + # self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1)) + # self.activate = ScaledLeakyReLU(0.2) + self.activate = FusedLeakyReLU(out_channel) + + def forward(self, input, style, noise=None): + out = self.conv(input, style) + out = self.noise(out, noise=noise) + # out = out + self.bias + out = self.activate(out) + + return out + + +class ToRGB(nn.Module): + def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]): + super().__init__() + + if upsample: + self.upsample = Upsample(blur_kernel) + + self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False) + self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1)) + + def forward(self, input, style, skip=None): + out = self.conv(input, style) + out = out + self.bias + + if skip is not None: + skip = self.upsample(skip) + + out = out + skip + + return out + + +class Generator(nn.Module): + def __init__( + self, + size, + style_dim, + n_mlp, + channel_multiplier=2, + blur_kernel=[1, 3, 3, 1], + lr_mlp=0.01, + ): + super().__init__() + + self.size = size + + self.style_dim = style_dim + + layers = [PixelNorm()] + + for i in range(n_mlp): + layers.append( + EqualLinear( + style_dim, style_dim, lr_mul=lr_mlp, activation="fused_lrelu" + ) + ) + + self.style = 
nn.Sequential(*layers) + + self.channels = { + 4: 512, + 8: 512, + 16: 512, + 32: 512, + 64: 256 * channel_multiplier, + 128: 128 * channel_multiplier, + 256: 64 * channel_multiplier, + 512: 32 * channel_multiplier, + 1024: 16 * channel_multiplier, + } + + self.input = ConstantInput(self.channels[4]) + self.conv1 = StyledConv( + self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel + ) + self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False) + + self.log_size = int(math.log(size, 2)) + self.num_layers = (self.log_size - 2) * 2 + 1 + + self.convs = nn.ModuleList() + self.upsamples = nn.ModuleList() + self.to_rgbs = nn.ModuleList() + self.noises = nn.Module() + + in_channel = self.channels[4] + + for layer_idx in range(self.num_layers): + res = (layer_idx + 5) // 2 + shape = [1, 1, 2 ** res, 2 ** res] + self.noises.register_buffer(f"noise_{layer_idx}", torch.randn(*shape)) + + for i in range(3, self.log_size + 1): + out_channel = self.channels[2 ** i] + + self.convs.append( + StyledConv( + in_channel, + out_channel, + 3, + style_dim, + upsample=True, + blur_kernel=blur_kernel, + ) + ) + + self.convs.append( + StyledConv( + out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel + ) + ) + + self.to_rgbs.append(ToRGB(out_channel, style_dim)) + + in_channel = out_channel + + self.n_latent = self.log_size * 2 - 2 + + def make_noise(self): + device = self.input.input.device + + noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)] + + for i in range(3, self.log_size + 1): + for _ in range(2): + noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device)) + + return noises + + def mean_latent(self, n_latent): + latent_in = torch.randn( + n_latent, self.style_dim, device=self.input.input.device + ) + latent = self.style(latent_in).mean(0, keepdim=True) + + return latent + + def get_latent(self, input): + return self.style(input) + + def forward( + self, + styles, + return_latents=False, + inject_index=None, + truncation=1, + truncation_latent=None, + input_is_latent=False, + noise=None, + randomize_noise=True, + ): + if not input_is_latent: + styles = [self.style(s) for s in styles] + + if noise is None: + if randomize_noise: + noise = [None] * self.num_layers + else: + noise = [ + getattr(self.noises, f"noise_{i}") for i in range(self.num_layers) + ] + + if truncation < 1: + style_t = [] + + for style in styles: + style_t.append( + truncation_latent + truncation * (style - truncation_latent) + ) + + styles = style_t + + if len(styles) < 2: + inject_index = self.n_latent + + if styles[0].ndim < 3: + latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) + + else: + latent = styles[0] + + else: + if inject_index is None: + inject_index = random.randint(1, self.n_latent - 1) + + latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) + latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1) + + latent = torch.cat([latent, latent2], 1) + + out = self.input(latent) + out = self.conv1(out, latent[:, 0], noise=noise[0]) + + skip = self.to_rgb1(out, latent[:, 1]) + + i = 1 + for conv1, conv2, noise1, noise2, to_rgb in zip( + self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs + ): + out = conv1(out, latent[:, i], noise=noise1) + out = conv2(out, latent[:, i + 1], noise=noise2) + skip = to_rgb(out, latent[:, i + 2], skip) + + i += 2 + + image = skip + + if return_latents: + return image, latent + + else: + return image, None + + +class ConvLayer(nn.Sequential): + def __init__( + self, + in_channel, + 
out_channel, + kernel_size, + downsample=False, + blur_kernel=[1, 3, 3, 1], + bias=True, + activate=True, + ): + layers = [] + + if downsample: + factor = 2 + p = (len(blur_kernel) - factor) + (kernel_size - 1) + pad0 = (p + 1) // 2 + pad1 = p // 2 + + layers.append(Blur(blur_kernel, pad=(pad0, pad1))) + + stride = 2 + self.padding = 0 + + else: + stride = 1 + self.padding = kernel_size // 2 + + layers.append( + EqualConv2d( + in_channel, + out_channel, + kernel_size, + padding=self.padding, + stride=stride, + bias=bias and not activate, + ) + ) + + if activate: + layers.append(FusedLeakyReLU(out_channel, bias=bias)) + + super().__init__(*layers) + + +class ResBlock(nn.Module): + def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]): + super().__init__() + + self.conv1 = ConvLayer(in_channel, in_channel, 3) + self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True) + + self.skip = ConvLayer( + in_channel, out_channel, 1, downsample=True, activate=False, bias=False + ) + + def forward(self, input): + out = self.conv1(input) + out = self.conv2(out) + + skip = self.skip(input) + out = (out + skip) / math.sqrt(2) + + return out + + +class Discriminator(nn.Module): + def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]): + super().__init__() + + channels = { + 4: 512, + 8: 512, + 16: 512, + 32: 512, + 64: 256 * channel_multiplier, + 128: 128 * channel_multiplier, + 256: 64 * channel_multiplier, + 512: 32 * channel_multiplier, + 1024: 16 * channel_multiplier, + } + + convs = [ConvLayer(3, channels[size], 1)] + + log_size = int(math.log(size, 2)) + + in_channel = channels[size] + + for i in range(log_size, 2, -1): + out_channel = channels[2 ** (i - 1)] + + convs.append(ResBlock(in_channel, out_channel, blur_kernel)) + + in_channel = out_channel + + self.convs = nn.Sequential(*convs) + + self.stddev_group = 4 + self.stddev_feat = 1 + + self.final_conv = ConvLayer(in_channel + 1, channels[4], 3) + self.final_linear = nn.Sequential( + EqualLinear(channels[4] * 4 * 4, channels[4], activation="fused_lrelu"), + EqualLinear(channels[4], 1), + ) + + def forward(self, input): + out = self.convs(input) + + batch, channel, height, width = out.shape + group = min(batch, self.stddev_group) + stddev = out.view( + group, -1, self.stddev_feat, channel // self.stddev_feat, height, width + ) + stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8) + stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2) + stddev = stddev.repeat(group, 1, height, width) + out = torch.cat([out, stddev], 1) + + out = self.final_conv(out) + + out = out.view(batch, -1) + out = self.final_linear(out) + + return out diff --git a/codes/models/archs/stylegan/stylegan2_unet_disc.py b/codes/models/archs/stylegan/stylegan2_unet_disc.py deleted file mode 100644 index 4b319879..00000000 --- a/codes/models/archs/stylegan/stylegan2_unet_disc.py +++ /dev/null @@ -1,243 +0,0 @@ -from functools import partial -from math import log2 -from random import random -import numpy as np - -import torch -import torch.nn as nn -import torch.nn.functional as F -import models.archs.stylegan.stylegan2 as sg2 -import models.steps.losses as L - - -def leaky_relu(p=0.2): - return nn.LeakyReLU(p) - - -def double_conv(chan_in, chan_out): - return nn.Sequential( - nn.Conv2d(chan_in, chan_out, 3, padding=1), - leaky_relu(), - nn.Conv2d(chan_out, chan_out, 3, padding=1), - leaky_relu() - ) - - -class Flatten(nn.Module): - def __init__(self, index): - super().__init__() - self.index = index - def forward(self, x): - 
return x.flatten(self.index) - - -class DownBlock(nn.Module): - def __init__(self, input_channels, filters, downsample=True): - super().__init__() - self.conv_res = nn.Conv2d(input_channels, filters, 1, stride = (2 if downsample else 1)) - - self.net = double_conv(input_channels, filters) - self.down = nn.Conv2d(filters, filters, 3, padding = 1, stride = 2) if downsample else None - - def forward(self, x): - res = self.conv_res(x) - x = self.net(x) - unet_res = x - - if self.down is not None: - x = self.down(x) - - x = x + res - return x, unet_res - - -class UpBlock(nn.Module): - def __init__(self, input_channels, filters): - super().__init__() - self.conv_res = nn.ConvTranspose2d(input_channels // 2, filters, 1, stride = 2) - self.net = double_conv(input_channels, filters) - self.up = nn.Upsample(scale_factor = 2, mode='bilinear', align_corners=False) - self.input_channels = input_channels - self.filters = filters - - def forward(self, x, res): - *_, h, w = x.shape - conv_res = self.conv_res(x, output_size = (h * 2, w * 2)) - x = self.up(x) - x = torch.cat((x, res), dim=1) - x = self.net(x) - x = x + conv_res - return x - - -class StyleGan2UnetDiscriminator(nn.Module): - def __init__(self, image_size, network_capacity = 16, fmap_max = 512, input_filters=3): - super().__init__() - num_layers = int(log2(image_size) - 3) - - blocks = [] - filters = [input_filters] + [(network_capacity) * (2 ** i) for i in range(num_layers + 1)] - - set_fmap_max = partial(min, fmap_max) - filters = list(map(set_fmap_max, filters)) - filters[-1] = filters[-2] - - chan_in_out = list(zip(filters[:-1], filters[1:])) - chan_in_out = list(map(list, chan_in_out)) - - down_blocks = [] - attn_blocks = [] - - for ind, (in_chan, out_chan) in enumerate(chan_in_out): - num_layer = ind + 1 - is_not_last = ind != (len(chan_in_out) - 1) - - block = DownBlock(in_chan, out_chan, downsample = is_not_last) - down_blocks.append(block) - - attn_fn = sg2.attn_and_ff(out_chan) - attn_blocks.append(attn_fn) - - self.down_blocks = nn.ModuleList(down_blocks) - self.attn_blocks = nn.ModuleList(attn_blocks) - - last_chan = filters[-1] - - self.to_logit = nn.Sequential( - leaky_relu(), - nn.AvgPool2d(image_size // (2 ** num_layers)), - Flatten(1), - nn.Linear(last_chan, 1) - ) - - self.conv = double_conv(last_chan, last_chan) - - dec_chan_in_out = chan_in_out[:-1][::-1] - self.up_blocks = nn.ModuleList(list(map(lambda c: UpBlock(c[1] * 2, c[0]), dec_chan_in_out))) - self.conv_out = nn.Conv2d(input_filters, 1, 1) - - def forward(self, x): - b, *_ = x.shape - - residuals = [] - - for (down_block, attn_block) in zip(self.down_blocks, self.attn_blocks): - x, unet_res = down_block(x) - residuals.append(unet_res) - - if attn_block is not None: - x = attn_block(x) - - x = self.conv(x) + x - enc_out = self.to_logit(x) - - for (up_block, res) in zip(self.up_blocks, residuals[:-1][::-1]): - x = up_block(x, res) - - dec_out = self.conv_out(x) - return dec_out, enc_out - - -def warmup(start, end, max_steps, current_step): - if current_step > max_steps: - return end - return (end - start) * (current_step / max_steps) + start - - -def mask_src_tgt(source, target, mask): - return source * mask + (1 - mask) * target - - -def cutmix(source, target, coors, alpha = 1.): - source, target = map(torch.clone, (source, target)) - ((y0, y1), (x0, x1)), _ = coors - source[:, :, y0:y1, x0:x1] = target[:, :, y0:y1, x0:x1] - return source - - -def cutmix_coordinates(height, width, alpha = 1.): - lam = np.random.beta(alpha, alpha) - - cx = np.random.uniform(0, width) 
- cy = np.random.uniform(0, height) - w = width * np.sqrt(1 - lam) - h = height * np.sqrt(1 - lam) - x0 = int(np.round(max(cx - w / 2, 0))) - x1 = int(np.round(min(cx + w / 2, width))) - y0 = int(np.round(max(cy - h / 2, 0))) - y1 = int(np.round(min(cy + h / 2, height))) - - return ((y0, y1), (x0, x1)), lam - - -class StyleGan2UnetDivergenceLoss(L.ConfigurableLoss): - def __init__(self, opt, env): - super().__init__(opt, env) - self.real = opt['real'] - self.fake = opt['fake'] - self.discriminator = opt['discriminator'] - self.for_gen = opt['gen_loss'] - self.gp_frequency = opt['gradient_penalty_frequency'] - self.noise = opt['noise'] if 'noise' in opt.keys() else 0 - self.image_size = opt['image_size'] - self.cr_weight = .2 - - def forward(self, net, state): - real_input = state[self.real] - fake_input = state[self.fake] - if self.noise != 0: - fake_input = fake_input + torch.rand_like(fake_input) * self.noise - real_input = real_input + torch.rand_like(real_input) * self.noise - - D = self.env['discriminators'][self.discriminator] - fake_dec, fake_enc = D(fake_input) - fake_aug_images = D.module.aug_images - if self.for_gen: - return fake_enc.mean() + F.relu(1 + fake_dec).mean() - else: - dec_loss_coef = warmup(0, 1., 30000, self.env['step']) - cutmix_prob = warmup(0, 0.25, 30000, self.env['step']) - apply_cutmix = random() < cutmix_prob - - real_input.requires_grad_() # <-- Needed to compute gradients on the input. - real_dec, real_enc = D(real_input) - real_aug_images = D.module.aug_images - enc_divergence = (F.relu(1 + real_enc) + F.relu(1 - fake_enc)).mean() - dec_divergence = (F.relu(1 + real_dec) + F.relu(1 - fake_dec)).mean() - disc_loss = enc_divergence + dec_divergence * dec_loss_coef - - if apply_cutmix: - mask = cutmix( - torch.ones_like(real_dec), - torch.zeros_like(real_dec), - cutmix_coordinates(self.image_size, self.image_size) - ) - - if random() > 0.5: - mask = 1 - mask - - cutmix_images = mask_src_tgt(real_aug_images, fake_aug_images, mask) - cutmix_dec_out, cutmix_enc_out = D.module.D(cutmix_images) # Bypass implied augmentor - hence D.module.D - - cutmix_enc_divergence = F.relu(1 - cutmix_enc_out).mean() - cutmix_dec_divergence = F.relu(1 + (mask * 2 - 1) * cutmix_dec_out).mean() - disc_loss = disc_loss + cutmix_enc_divergence + cutmix_dec_divergence - - cr_cutmix_dec_out = mask_src_tgt(real_dec, fake_dec, mask) - cr_loss = F.mse_loss(cutmix_dec_out, cr_cutmix_dec_out) * self.cr_weight - self.last_cr_loss = cr_loss.clone().detach().item() - - disc_loss = disc_loss + cr_loss * dec_loss_coef - - # Apply gradient penalty. TODO: migrate this elsewhere. 
- if self.env['step'] % self.gp_frequency == 0: - from models.archs.stylegan.stylegan2 import gradient_penalty - if random() < .5: - gp = gradient_penalty(real_input, real_enc) - else: - gp = gradient_penalty(real_input, real_dec) * dec_loss_coef - self.metrics.append(("gradient_penalty", gp.clone().detach())) - disc_loss = disc_loss + gp - - real_input.requires_grad_(requires_grad=False) - return disc_loss diff --git a/codes/models/networks.py b/codes/models/networks.py index bc947f2a..056a1f19 100644 --- a/codes/models/networks.py +++ b/codes/models/networks.py @@ -6,8 +6,7 @@ import munch import torch import torchvision from munch import munchify -import models.archs.stylegan.stylegan2 as stylegan2 -import models.archs.stylegan.stylegan2_unet_disc as stylegan2_unet +import models.archs.stylegan.stylegan2_lucidrains as stylegan2 import models.archs.fixup_resnet.DiscriminatorResnet_arch as DiscriminatorResnet_arch import models.archs.RRDBNet_arch as RRDBNet_arch @@ -228,9 +227,6 @@ def define_D_net(opt_net, img_sz=None, wrap=False): attn = opt_net['attn_layers'] if 'attn_layers' in opt_net.keys() else [] disc = stylegan2.StyleGan2Discriminator(image_size=opt_net['image_size'], input_filters=opt_net['in_nc'], attn_layers=attn) netD = stylegan2.StyleGan2Augmentor(disc, opt_net['image_size'], types=opt_net['augmentation_types'], prob=opt_net['augmentation_probability']) - elif which_model == "stylegan2_unet": - disc = stylegan2_unet.StyleGan2UnetDiscriminator(image_size=opt_net['image_size'], input_filters=opt_net['in_nc']) - netD = stylegan2.StyleGan2Augmentor(disc, opt_net['image_size'], types=opt_net['augmentation_types'], prob=opt_net['augmentation_probability']) elif which_model == "rrdb_disc": netD = RRDBNet_arch.RRDBDiscriminator(opt_net['in_nc'], opt_net['nf'], opt_net['nb'], blocks_per_checkpoint=3) else: diff --git a/codes/models/steps/losses.py b/codes/models/steps/losses.py index 84e12e3f..77d748d8 100644 --- a/codes/models/steps/losses.py +++ b/codes/models/steps/losses.py @@ -312,7 +312,7 @@ class DiscriminatorGanLoss(ConfigurableLoss): if self.gradient_penalty: # Apply gradient penalty. TODO: migrate this elsewhere. - from models.archs.stylegan.stylegan2 import gradient_penalty + from models.archs.stylegan.stylegan2_lucidrains import gradient_penalty assert len(real) == 1 # Grad penalty doesn't currently support multi-input discriminators. gp = gradient_penalty(real[0], d_real) self.metrics.append(("gradient_penalty", gp.clone().detach())) diff --git a/codes/models/steps/tecogan_losses.py b/codes/models/steps/tecogan_losses.py index 70748328..7f921204 100644 --- a/codes/models/steps/tecogan_losses.py +++ b/codes/models/steps/tecogan_losses.py @@ -1,6 +1,6 @@ from torch.cuda.amp import autocast -from models.archs.stylegan.stylegan2 import gradient_penalty +from models.archs.stylegan.stylegan2_lucidrains import gradient_penalty from models.steps.losses import ConfigurableLoss, GANLoss, extract_params_from_state, get_basic_criterion_for_name from models.flownet2.networks.resample2d_package.resample2d import Resample2d from models.steps.injectors import Injector diff --git a/codes/scripts/stylegan2/convert_weights.py b/codes/scripts/stylegan2/convert_weights.py new file mode 100644 index 00000000..b165d25e --- /dev/null +++ b/codes/scripts/stylegan2/convert_weights.py @@ -0,0 +1,292 @@ +# Converts from Tensorflow Stylegan2 weights to weights used by this model. 
+# Original source: https://raw.githubusercontent.com/rosinality/stylegan2-pytorch/master/convert_weight.py +# Adapted to lucidrains' Stylegan implementation. +# +# Also doesn't require you to install Tensorflow 1.15 or clone the nVidia repo. + +import argparse +import os +import sys +import pickle +import math + +import torch +import numpy as np +from torchvision import utils + +from models.archs.stylegan.stylegan2_rosinality import Generator, Discriminator + + +# Converts from the TF state_dict input provided into the vars originally expected from the rosinality converter. +def get_vars(vars, source_name): + net_name = source_name.split('/')[0] + vars_as_tuple_list = vars[net_name]['variables'] + result_vars = {} + for t in vars_as_tuple_list: + result_vars[t[0]] = t[1] + return result_vars, source_name.replace(net_name + "/", "") + +def get_vars_direct(vars, source_name): + v, n = get_vars(vars, source_name) + return v[n] + + +def convert_modconv(vars, source_name, target_name, flip=False): + vars, source_name = get_vars(vars, source_name) + weight = vars[source_name + "/weight"] + mod_weight = vars[source_name + "/mod_weight"] + mod_bias = vars[source_name + "/mod_bias"] + noise = vars[source_name + "/noise_strength"] + bias = vars[source_name + "/bias"] + + dic = { + "conv.weight": np.expand_dims(weight.transpose((3, 2, 0, 1)), 0), + "conv.modulation.weight": mod_weight.transpose((1, 0)), + "conv.modulation.bias": mod_bias + 1, + "noise.weight": np.array([noise]), + "activate.bias": bias, + } + + dic_torch = {} + + for k, v in dic.items(): + dic_torch[target_name + "." + k] = torch.from_numpy(v) + + if flip: + dic_torch[target_name + ".conv.weight"] = torch.flip( + dic_torch[target_name + ".conv.weight"], [3, 4] + ) + + return dic_torch + + +def convert_conv(vars, source_name, target_name, bias=True, start=0): + vars, source_name = get_vars(vars, source_name) + weight = vars[source_name + "/weight"] + + dic = {"weight": weight.transpose((3, 2, 0, 1))} + + if bias: + dic["bias"] = vars[source_name + "/bias"] + + dic_torch = {} + + dic_torch[target_name + f".{start}.weight"] = torch.from_numpy(dic["weight"]) + + if bias: + dic_torch[target_name + f".{start + 1}.bias"] = torch.from_numpy(dic["bias"]) + + return dic_torch + + +def convert_torgb(vars, source_name, target_name): + vars, source_name = get_vars(vars, source_name) + weight = vars[source_name + "/weight"] + mod_weight = vars[source_name + "/mod_weight"] + mod_bias = vars[source_name + "/mod_bias"] + bias = vars[source_name + "/bias"] + + dic = { + "conv.weight": np.expand_dims(weight.transpose((3, 2, 0, 1)), 0), + "conv.modulation.weight": mod_weight.transpose((1, 0)), + "conv.modulation.bias": mod_bias + 1, + "bias": bias.reshape((1, 3, 1, 1)), + } + + dic_torch = {} + + for k, v in dic.items(): + dic_torch[target_name + "." + k] = torch.from_numpy(v) + + return dic_torch + + +def convert_dense(vars, source_name, target_name): + vars, source_name = get_vars(vars, source_name) + weight = vars[source_name + "/weight"] + bias = vars[source_name + "/bias"] + + dic = {"weight": weight.transpose((1, 0)), "bias": bias} + + dic_torch = {} + + for k, v in dic.items(): + dic_torch[target_name + "." 
+ k] = torch.from_numpy(v) + + return dic_torch + + +def update(state_dict, new): + for k, v in new.items(): + state_dict[k] = v + + +def discriminator_fill_statedict(statedict, vars, size): + log_size = int(math.log(size, 2)) + + update(statedict, convert_conv(vars, f"{size}x{size}/FromRGB", "convs.0")) + + conv_i = 1 + + for i in range(log_size - 2, 0, -1): + reso = 4 * 2 ** i + update( + statedict, + convert_conv(vars, f"{reso}x{reso}/Conv0", f"convs.{conv_i}.conv1"), + ) + update( + statedict, + convert_conv( + vars, f"{reso}x{reso}/Conv1_down", f"convs.{conv_i}.conv2", start=1 + ), + ) + update( + statedict, + convert_conv( + vars, f"{reso}x{reso}/Skip", f"convs.{conv_i}.skip", start=1, bias=False + ), + ) + conv_i += 1 + + update(statedict, convert_conv(vars, f"4x4/Conv", "final_conv")) + update(statedict, convert_dense(vars, f"4x4/Dense0", "final_linear.0")) + update(statedict, convert_dense(vars, f"Output", "final_linear.1")) + + return statedict + + +def fill_statedict(state_dict, vars, size): + log_size = int(math.log(size, 2)) + + for i in range(8): + update(state_dict, convert_dense(vars, f"G_mapping/Dense{i}", f"style.{i + 1}")) + + update( + state_dict, + { + "input.input": torch.from_numpy( + get_vars_direct(vars, "G_synthesis/4x4/Const/const") + ) + }, + ) + + update(state_dict, convert_torgb(vars, "G_synthesis/4x4/ToRGB", "to_rgb1")) + + for i in range(log_size - 2): + reso = 4 * 2 ** (i + 1) + update( + state_dict, + convert_torgb(vars, f"G_synthesis/{reso}x{reso}/ToRGB", f"to_rgbs.{i}"), + ) + + update(state_dict, convert_modconv(vars, "G_synthesis/4x4/Conv", "conv1")) + + conv_i = 0 + + for i in range(log_size - 2): + reso = 4 * 2 ** (i + 1) + update( + state_dict, + convert_modconv( + vars, + f"G_synthesis/{reso}x{reso}/Conv0_up", + f"convs.{conv_i}", + flip=True, + ), + ) + update( + state_dict, + convert_modconv( + vars, f"G_synthesis/{reso}x{reso}/Conv1", f"convs.{conv_i + 1}" + ), + ) + conv_i += 2 + + for i in range(0, (log_size - 2) * 2 + 1): + update( + state_dict, + { + f"noises.noise_{i}": torch.from_numpy( + get_vars_direct(vars, f"G_synthesis/noise{i}") + ) + }, + ) + + return state_dict + + +if __name__ == "__main__": + device = "cuda" + + parser = argparse.ArgumentParser( + description="Tensorflow to pytorch model checkpoint converter" + ) + parser.add_argument( + "--gen", action="store_true", help="convert the generator weights" + ) + parser.add_argument( + "--disc", action="store_true", help="convert the discriminator weights" + ) + parser.add_argument( + "--channel_multiplier", + type=int, + default=2, + help="channel multiplier factor. config-f = 2, else = 1", + ) + parser.add_argument("path", metavar="PATH", help="path to the tensorflow weights") + + args = parser.parse_args() + sys.path.append('scripts\\stylegan2') + + import dnnlib + from dnnlib.tflib.network import generator, discriminator, gen_ema + + with open(args.path, "rb") as f: + pickle.load(f) + + # Weight names are ordered by size. The last name will be something like '1024x1024/'. We just need to grab that first number. 
+ size = int(generator['G_synthesis']['variables'][-1][0].split('x')[0]) + + g = Generator(size, 512, 8, channel_multiplier=args.channel_multiplier) + state_dict = g.state_dict() + state_dict = fill_statedict(state_dict, gen_ema, size) + + g.load_state_dict(state_dict, strict=True) + + latent_avg = torch.from_numpy(get_vars_direct(gen_ema, "G/dlatent_avg")) + + ckpt = {"g_ema": state_dict, "latent_avg": latent_avg} + + if args.gen: + g_train = Generator(size, 512, 8, channel_multiplier=args.channel_multiplier) + g_train_state = g_train.state_dict() + g_train_state = fill_statedict(g_train_state, generator, size) + ckpt["g"] = g_train_state + + if args.disc: + disc = Discriminator(size, channel_multiplier=args.channel_multiplier) + d_state = disc.state_dict() + d_state = discriminator_fill_statedict(d_state, discriminator.vars, size) + ckpt["d"] = d_state + + name = os.path.splitext(os.path.basename(args.path))[0] + torch.save(ckpt, name + ".pt") + + batch_size = {256: 16, 512: 9, 1024: 4} + n_sample = batch_size.get(size, 25) + + g = g.to(device) + + z = np.random.RandomState(0).randn(n_sample, 512).astype("float32") + + with torch.no_grad(): + img_pt, _ = g( + [torch.from_numpy(z).to(device)], + truncation=0.5, + truncation_latent=latent_avg.to(device), + randomize_noise=False, + ) + + utils.save_image( + img_pt, name + ".png", nrow=n_sample, normalize=True, range=(-1, 1) + ) diff --git a/codes/scripts/stylegan2/dnnlib/tflib/network.py b/codes/scripts/stylegan2/dnnlib/tflib/network.py new file mode 100644 index 00000000..63f4544e --- /dev/null +++ b/codes/scripts/stylegan2/dnnlib/tflib/network.py @@ -0,0 +1,17 @@ +# Pretends to be the stylegan2 Network class for intercepting pickle load requests. +# Horrible hack. Please don't judge me. + +# Globals for storing these networks because I have no idea how pickle is doing this internally. +generator, discriminator, gen_ema = {}, {}, {} + +class Network: + def __setstate__(self, state: dict) -> None: + global generator, discriminator, gen_ema + name = state['name'] + if name in ['G_synthesis', 'G_mapping', 'G', 'G_main']: + if name != 'G' and name not in generator.keys(): + generator[name] = state + else: + gen_ema[name] = state + elif name in ['D']: + discriminator[name] = state
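Note: below is a minimal usage sketch for a checkpoint produced by convert_weights.py, mirroring the sanity-check render at the bottom of that script. The checkpoint filename, image size and channel multiplier are illustrative assumptions and must match whatever TF pickle was converted (per the argument help, config-f networks use channel_multiplier=2). The fused_bias_act_cuda and upfirdn2d_cuda extensions from codes/models/archs/stylegan/op/setup.py must be built first.

# Usage sketch (assumed filename "stylegan2-config-f.pt", size=1024, channel_multiplier=2).
import torch
from torchvision import utils
from models.archs.stylegan.stylegan2_rosinality import Generator

device = "cuda"
ckpt = torch.load("stylegan2-config-f.pt")  # output of convert_weights.py

g_ema = Generator(1024, 512, 8, channel_multiplier=2).to(device)
g_ema.load_state_dict(ckpt["g_ema"], strict=True)
g_ema.eval()

latent_avg = ckpt["latent_avg"].to(device)

with torch.no_grad():
    z = torch.randn(4, 512, device=device)
    # Truncate towards the average W latent, as the converter does for its test image.
    images, _ = g_ema([z], truncation=0.5, truncation_latent=latent_avg, randomize_noise=False)

utils.save_image(images, "samples.png", nrow=2, normalize=True, range=(-1, 1))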