DL-Art-School/codes/scripts/stylegan2/convert_weights_lucidrains.py
2020-12-18 09:55:38 -07:00

293 lines
8.7 KiB
Python

# Converts from Tensorflow Stylegan2 weights to weights used by this model.
# Original source: https://raw.githubusercontent.com/rosinality/stylegan2-pytorch/master/convert_weight.py
# Adapted to lucidrains' Stylegan implementation.
#
# Also doesn't require you to install Tensorflow 1.15 or clone the nVidia repo.
# THIS DOES NOT CURRENTLY WORK.
# It does transfer all weights from the stylegan model to the lucidrains one, but does not produce correct results.
# The rosinality script this was stolen from has some "odd" intricacies that may be the cause of this: for example
# weight "flipping" in the conv layers which I do not understand. It may also be because I botched some of the mods
# required to make the lucidrains implementation conformant. I'll (maybe) get back to this some day.
import argparse
import os
import sys
import pickle
import math
import torch
import numpy as np
from torchvision import utils
# Converts from the TF state_dict input provided into the vars originally expected from the rosinality converter.
from models.stylegan.stylegan2_lucidrains import StyleGan2GeneratorWithLatent
def get_vars(vars, source_name):
    """Split a variable path into its network's variable table and the remaining path.

    The first '/'-separated segment of ``source_name`` selects the network inside
    ``vars``. That network's ``'variables'`` entry is a sequence of items whose
    element 0 is the variable name and element 1 is its value; it is flattened
    into a plain dict here.

    Args:
        vars: Mapping of network name -> mapping that contains a 'variables' sequence.
        source_name: Variable path whose first segment is the network name.

    Returns:
        A ``(lookup, relative_name)`` tuple: the name -> value dict for the
        network, and ``source_name`` with every occurrence of ``"<network>/"``
        removed.
    """
    network = source_name.partition('/')[0]
    # Later duplicates of a name overwrite earlier ones, as in a plain assignment loop.
    lookup = {entry[0]: entry[1] for entry in vars[network]['variables']}
    relative_name = source_name.replace(network + "/", "")
    return lookup, relative_name
def get_vars_direct(vars, source_name):
    """Return the value of the single variable addressed by ``source_name``.

    Args:
        vars: Mapping of network name -> variable container, as accepted by ``get_vars``.
        source_name: Full variable path; its first segment names the network.

    Returns:
        The stored value for that variable (raises KeyError if it is absent).
    """
    lookup, relative_name = get_vars(vars, source_name)
    value = lookup[relative_name]
    return value
def convert_modconv(vars, source_name, target_name, flip=False, numeral=1):
    """Build the PyTorch state-dict entries for one modulated-conv layer.

    Reads ``weight``, ``mod_weight``, ``mod_bias``, ``noise_strength`` and ``bias``
    from under ``source_name`` and emits them under ``target_name`` using the
    ``conv{numeral}`` / ``to_style{numeral}`` / ``noise{numeral}_scale`` /
    ``activation{numeral}`` key names.

    Args:
        vars: Mapping of network name -> variable container (see ``get_vars``).
        source_name: Source layer path; its first segment names the network.
        target_name: Prefix for every key in the returned dict.
        flip: When true, the converted conv weight is reversed along dims 2 and 3.
        numeral: Index embedded in the destination key names.

    Returns:
        Dict of destination key -> ``torch.Tensor``.
    """
    table, prefix = get_vars(vars, source_name)

    # Fetch everything up front so a missing variable fails before any conversion work.
    weight, mod_weight, mod_bias, noise, bias = [
        table[prefix + "/" + leaf]
        for leaf in ("weight", "mod_weight", "mod_bias", "noise_strength", "bias")
    ]

    converted = {
        # Axis order (3, 2, 0, 1): presumably TF HWIO -> PyTorch OIHW; TODO confirm.
        f"conv{numeral}.weight": weight.transpose((3, 2, 0, 1)),
        f"to_style{numeral}.weight": mod_weight.transpose((1, 0)),
        # The style bias is shifted by one before being stored.
        f"to_style{numeral}.bias": mod_bias + 1,
        # Wrapped in a list so the stored value is a 1-element array.
        f"noise{numeral}_scale": np.array([noise]),
        f"activation{numeral}.bias": bias,
    }

    tensors = {
        target_name + "." + key: torch.from_numpy(array)
        for key, array in converted.items()
    }

    if flip:
        conv_key = target_name + f".conv{numeral}.weight"
        tensors[conv_key] = torch.flip(tensors[conv_key], [2, 3])

    return tensors
def convert_conv(vars, source_name, target_name, bias=True, start=0):
    """Build the PyTorch state-dict entries for a plain convolution layer.

    Args:
        vars: Mapping of network name -> variable container (see ``get_vars``).
        source_name: Source layer path; its first segment names the network.
        target_name: Prefix for the produced keys.
        bias: Whether to also convert the layer's ``bias`` variable.
        start: Sub-module index for the weight; the bias goes to ``start + 1``.

    Returns:
        Dict with ``"{target_name}.{start}.weight"`` and, when ``bias`` is true,
        ``"{target_name}.{start + 1}.bias"``, each mapped to a ``torch.Tensor``.
    """
    table, prefix = get_vars(vars, source_name)

    # Axis order (3, 2, 0, 1): presumably TF HWIO -> PyTorch OIHW; TODO confirm.
    kernel = table[prefix + "/weight"].transpose((3, 2, 0, 1))
    bias_values = table[prefix + "/bias"] if bias else None

    converted = {target_name + f".{start}.weight": torch.from_numpy(kernel)}
    if bias:
        converted[target_name + f".{start + 1}.bias"] = torch.from_numpy(bias_values)
    return converted
def convert_torgb(vars, source_name, target_name):
    """Build the PyTorch state-dict entries for a ToRGB layer.

    Only the conv weight and the style (modulation) weight/bias are emitted.
    The layer's own ``bias`` variable is read but not transferred.

    Args:
        vars: Mapping of network name -> variable container (see ``get_vars``).
        source_name: Source layer path; its first segment names the network.
        target_name: Prefix for the produced keys.

    Returns:
        Dict with ``conv.weight``, ``to_style.weight`` and ``to_style.bias``
        under ``target_name``, each mapped to a ``torch.Tensor``.
    """
    table, prefix = get_vars(vars, source_name)

    weight = table[prefix + "/weight"]
    mod_weight = table[prefix + "/mod_weight"]
    mod_bias = table[prefix + "/mod_bias"]
    # Still looked up (so a missing variable raises), but no destination is known for it yet.
    bias = table[prefix + "/bias"]

    converted = {
        "conv.weight": weight.transpose((3, 2, 0, 1)),
        "to_style.weight": mod_weight.transpose((1, 0)),
        # The style bias is shifted by one before being stored.
        "to_style.bias": mod_bias + 1,
        # "bias": bias.reshape((1, 3, 1, 1)),  TODO: where is this?
    }

    return {
        target_name + "." + key: torch.from_numpy(array)
        for key, array in converted.items()
    }
def convert_dense(vars, source_name, target_name):
    """Build the PyTorch state-dict entries for a fully-connected layer.

    Args:
        vars: Mapping of network name -> variable container (see ``get_vars``).
        source_name: Source layer path; its first segment names the network.
        target_name: Prefix for the produced keys.

    Returns:
        Dict with ``"{target_name}.weight"`` (the source weight with its two
        axes swapped) and ``"{target_name}.bias"``, each a ``torch.Tensor``.
    """
    table, prefix = get_vars(vars, source_name)
    weight = table[prefix + "/weight"]
    bias = table[prefix + "/bias"]

    transposed = weight.transpose((1, 0))
    return {
        target_name + ".weight": torch.from_numpy(transposed),
        target_name + ".bias": torch.from_numpy(bias),
    }
def update(state_dict, new, strict=True):
    """Copy every entry of ``new`` into ``state_dict`` in place.

    Args:
        state_dict: Destination mapping; mutated.
        new: Mapping of key -> value to copy over. Values need a ``.shape``
            attribute when ``strict`` is true.
        strict: When true, each key must already exist in ``state_dict`` and
            the incoming value's shape must equal the existing value's shape.

    Raises:
        KeyError: In strict mode, when a key is not already in ``state_dict``.
        ValueError: In strict mode, when the shapes differ.

    Entries processed before a failing one remain applied.
    """
    for key, incoming in new.items():
        if strict and key not in state_dict:
            raise KeyError(key + " is not found")
        if strict and incoming.shape != state_dict[key].shape:
            raise ValueError(f"Shape mismatch: {incoming.shape} vs {state_dict[key].shape}")
        state_dict[key] = incoming
def discriminator_fill_statedict(statedict, vars, size):
    """Populate a discriminator state dict from the source variables.

    Converts the FromRGB conv, then one Conv0 / Conv1_down / Skip trio per
    resolution from the largest down to 8x8, then the 4x4 conv and the two
    final dense layers. Every write goes through ``update`` in strict mode.

    NOTE(review): ``get_vars`` treats the first path segment as the network
    name, and the paths built here start with a resolution (or ``Output``)
    rather than a network name. Confirm the intended prefix before relying
    on this function.

    Args:
        statedict: Destination state dict; mutated and returned.
        vars: Mapping of network name -> variable container (see ``get_vars``).
        size: Image resolution.

    Returns:
        The same ``statedict`` object that was passed in.
    """
    log_size = int(math.log(size, 2))

    update(statedict, convert_conv(vars, f"{size}x{size}/FromRGB", "convs.0"))

    # Destination blocks count up from 1 while the source resolution shrinks.
    for conv_i, level in enumerate(range(log_size - 2, 0, -1), start=1):
        reso = 4 * 2 ** level
        src = f"{reso}x{reso}"
        dst = f"convs.{conv_i}"

        update(statedict, convert_conv(vars, f"{src}/Conv0", f"{dst}.conv1"))
        update(statedict, convert_conv(vars, f"{src}/Conv1_down", f"{dst}.conv2", start=1))
        update(
            statedict,
            convert_conv(vars, f"{src}/Skip", f"{dst}.skip", start=1, bias=False),
        )

    update(statedict, convert_conv(vars, "4x4/Conv", "final_conv"))
    update(statedict, convert_dense(vars, "4x4/Dense0", "final_linear.0"))
    update(statedict, convert_dense(vars, "Output", "final_linear.1"))

    return statedict
def fill_statedict(state_dict, vars, size):
    """Populate a generator state dict from the source variables.

    Transfers, in order: the eight ``G_mapping`` dense layers, the 4x4 constant
    input, one ToRGB layer per resolution, the 4x4 conv, and then the
    ``Conv0_up`` / ``Conv1`` pair for every resolution above 4x4. Every write
    goes through ``update`` in strict mode, so unknown keys or shape mismatches
    raise.

    Args:
        state_dict: Destination state dict; mutated and returned.
        vars: Mapping of network name -> variable container (see ``get_vars``).
        size: Image resolution.

    Returns:
        The same ``state_dict`` object that was passed in.
    """
    log_size = int(math.log(size, 2))

    # Mapping network -> vectorizer.
    for layer in range(8):
        dense = convert_dense(vars, f"G_mapping/Dense{layer}", f"vectorizer.net.{layer}")
        update(state_dict, dense)

    # Constant input of the synthesis network.
    const_input = torch.from_numpy(get_vars_direct(vars, "G_synthesis/4x4/Const/const"))
    update(state_dict, {"gen.initial_block": const_input})

    # One ToRGB per block, starting at 4x4.
    for block in range(log_size - 1):
        reso = 4 * 2 ** block
        to_rgb = convert_torgb(
            vars, f"G_synthesis/{reso}x{reso}/ToRGB", f"gen.blocks.{block}.to_rgb"
        )
        update(state_dict, to_rgb)

    # The 4x4 block only has a single conv in the source.
    update(state_dict, convert_modconv(vars, "G_synthesis/4x4/Conv", "gen.blocks.0", numeral=1))

    for block in range(1, log_size - 1):
        reso = 4 * 2 ** block
        source_prefix = f"G_synthesis/{reso}x{reso}"
        target = f"gen.blocks.{block}"

        first_conv = convert_modconv(
            vars,
            f"{source_prefix}/Conv0_up",
            target,
            # flip=True,  # TODO: why??
            numeral=1,
        )
        update(state_dict, first_conv)

        second_conv = convert_modconv(vars, f"{source_prefix}/Conv1", target, numeral=2)
        update(state_dict, second_conv)

    # TODO: consider porting this, though I dont think it is necessary.
    # for i in range(0, (log_size - 2) * 2 + 1):
    #     update(
    #         state_dict,
    #         {
    #             f"noises.noise_{i}": torch.from_numpy(
    #                 get_vars_direct(vars, f"G_synthesis/noise{i}")
    #             )
    #         },
    #     )

    return state_dict
if __name__ == "__main__":
    # Script entry point: load a TF StyleGAN2 pickle, copy its weights into a
    # StyleGan2GeneratorWithLatent, save "<name>.pt" and render a "<name>.png" sample.

    # Fall back to CPU so the sample render at the end still runs without a GPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    parser = argparse.ArgumentParser(
        description="Tensorflow to pytorch model checkpoint converter"
    )
    parser.add_argument(
        "--gen", action="store_true", help="convert the generator weights"
    )
    # Still accepted for CLI compatibility; StyleGan2GeneratorWithLatent is
    # constructed below without a channel multiplier.
    parser.add_argument(
        "--channel_multiplier",
        type=int,
        default=2,
        help="channel multiplier factor. config-f = 2, else = 1",
    )
    parser.add_argument("path", metavar="PATH", help="path to the tensorflow weights")

    args = parser.parse_args()

    # Build the path with os.path.join so it resolves on POSIX as well as Windows
    # (a hard-coded backslash is just part of a directory name outside Windows).
    sys.path.append(os.path.join('scripts', 'stylegan2'))

    import dnnlib
    from dnnlib.tflib.network import generator, gen_ema

    # SECURITY: pickle.load can execute arbitrary code; only run this on trusted weight files.
    # The return value is discarded; the rest of the script reads `generator` / `gen_ema`.
    # Presumably the dnnlib module imported above fills those in as a side effect of
    # unpickling -- verify against scripts/stylegan2/dnnlib.
    with open(args.path, "rb") as f:
        pickle.load(f)

    # Weight names are ordered by size. The last name will be something like '1024x1024/<blah>'. We just need to grab that first number.
    size = int(generator['G_synthesis']['variables'][-1][0].split('x')[0])

    g = StyleGan2GeneratorWithLatent(image_size=size, latent_dim=512, style_depth=8)
    state_dict = g.state_dict()
    state_dict = fill_statedict(state_dict, gen_ema, size)

    g.load_state_dict(state_dict, strict=True)

    latent_avg = torch.from_numpy(get_vars_direct(gen_ema, "G/dlatent_avg"))

    ckpt = {"g_ema": state_dict, "latent_avg": latent_avg}

    if args.gen:
        # This previously called an undefined `Generator` and raised NameError.
        # Use the same generator class and arguments as for the EMA weights above.
        g_train = StyleGan2GeneratorWithLatent(image_size=size, latent_dim=512, style_depth=8)
        g_train_state = g_train.state_dict()
        g_train_state = fill_statedict(g_train_state, generator, size)
        ckpt["g"] = g_train_state

    name = os.path.splitext(os.path.basename(args.path))[0]
    torch.save(ckpt, name + ".pt")

    # Fewer samples for larger resolutions; 25 for any size not listed.
    batch_size = {256: 16, 512: 9, 1024: 4}
    n_sample = batch_size.get(size, 25)

    g = g.to(device)

    # NOTE(review): `z` is never passed to the generator; `g` is called with the
    # literal 8 below. Confirm what StyleGan2GeneratorWithLatent.forward expects.
    z = np.random.RandomState(5).randn(n_sample, 512).astype("float32")

    with torch.no_grad():
        img_pt, _ = g(8)

    # NOTE(review): newer torchvision releases name this keyword `value_range`
    # instead of `range` -- confirm against the installed version.
    utils.save_image(
        img_pt, name + ".png", nrow=n_sample, normalize=True, range=(-1, 1)
    )