Misc options to add support for training stylegan2-rosinality models:
- Allow image_folder_dataset to normalize inbound images
- ExtensibleTrainer can denormalize images on the output path
- Support .webp, an output of LSUN
- Support logistic GAN divergence loss
- Support stylegan2 TF weight extraction for discriminator
- New injector that produces latent noise (with separated paths)
- Modify FID evaluator to be operable with rosinality-style GANs
parent e7be4bdff3, commit 784b96c059
@@ -8,6 +8,8 @@ import numpy as np
 import torch
 import os
 
+from torchvision.transforms import Normalize
+
 from data import util
 # Builds a dataset created from a simple folder containing a list of training/test/validation images.
 from data.image_corruptor import ImageCorruptor
@@ -28,6 +30,13 @@ class ImageFolderDataset:
         # from the same video source. Search for 'fetch_alt_image' for more info.
         self.skip_lq = opt_get(opt, ['skip_lq'], False)
         self.disable_flip = opt_get(opt, ['disable_flip'], False)
+        if 'normalize' in opt.keys():
+            if opt['normalize'] == 'stylegan2_norm':
+                self.normalize = Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
+            else:
+                raise Exception('Unsupported normalize')
+        else:
+            self.normalize = None
         assert (self.target_hq_size // self.scale) % self.multiple == 0   # If we dont throw here, we get some really obscure errors.
         if not isinstance(self.paths, list):
             self.paths = [self.paths]
@@ -128,6 +137,8 @@ class ImageFolderDataset:
 
         # Convert to torch tensor
         hq = torch.from_numpy(np.ascontiguousarray(np.transpose(hs[0], (2, 0, 1)))).float()
+        if self.normalize:
+            hq = self.normalize(hq)
 
         out_dict = {'hq': hq, 'LQ_path': self.image_paths[item], 'HQ_path': self.image_paths[item]}
 
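For reference, the 'stylegan2_norm' option maps the usual [0, 1] image tensors into the [-1, 1] range that StyleGAN2 expects. A minimal sketch of the transform's effect (the shift/scale values are the ones in the diff above):

    import torch
    from torchvision.transforms import Normalize

    # Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) computes (x - 0.5) / 0.5 per channel,
    # i.e. it remaps [0, 1] -> [-1, 1].
    norm = Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    x = torch.rand(3, 4, 4)   # a stand-in image tensor in [0, 1]
    y = norm(x)
    assert y.min() >= -1 and y.max() <= 1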
@@ -70,7 +70,7 @@ class expand_greyscale(object):
 class Stylegan2Dataset(data.Dataset):
     def __init__(self, opt):
         super().__init__()
-        EXTS = ['jpg', 'jpeg', 'png']
+        EXTS = ['jpg', 'jpeg', 'png', 'webp']
         self.folder = opt['path']
         self.image_size = opt['target_size']
         self.paths = [p for ext in EXTS for p in Path(f'{self.folder}').glob(f'**/*.{ext}')]
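Note that loading .webp files only works if the image decoder supports it; assuming this dataset decodes through PIL, a quick sanity check is:

    from PIL import features

    # True when the installed Pillow build can decode/encode WebP images.
    print(features.check('webp'))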
@@ -248,13 +248,12 @@ def gradient_penalty(images, output, weight=10, return_structured_grads=False):
     return penalty
 
 def calc_pl_lengths(styles, images):
-    device = images.device
     num_pixels = images.shape[2] * images.shape[3]
-    pl_noise = torch.randn(images.shape, device=device) / math.sqrt(num_pixels)
+    pl_noise = torch.randn_like(images) / math.sqrt(num_pixels)
     outputs = (images * pl_noise).sum()
 
     pl_grads = torch_grad(outputs=outputs, inputs=styles,
-                          grad_outputs=torch.ones(outputs.shape, device=device),
+                          grad_outputs=torch.ones_like(outputs),
                           create_graph=True, retain_graph=True, only_inputs=True)[0]
 
     return (pl_grads ** 2).sum(dim=2).mean(dim=1).sqrt()
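The refactor above is behavior-preserving: torch.randn_like picks up shape, dtype and device from its argument, so the explicit device plumbing can go away. A small check of that equivalence:

    import torch

    x = torch.zeros(2, 3, 8, 8)
    a = torch.randn_like(x)
    b = torch.randn(x.shape, device=x.device, dtype=x.dtype)
    assert a.shape == b.shape and a.device == b.device and a.dtype == b.dtype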
@@ -850,6 +849,7 @@ class StyleGan2DivergenceLoss(L.ConfigurableLoss):
         self.for_gen = opt['gen_loss']
         self.gp_frequency = opt['gradient_penalty_frequency']
         self.noise = opt['noise'] if 'noise' in opt.keys() else 0
+        self.logistic = opt_get(opt, ['logistic'], False)  # Applies a logistic curve to the output logits, which is what the StyleGAN2 authors used.
 
     def forward(self, net, state):
         real_input = state[self.real]
@@ -861,11 +861,19 @@ class StyleGan2DivergenceLoss(L.ConfigurableLoss):
         D = self.env['discriminators'][self.discriminator]
         fake = D(fake_input)
         if self.for_gen:
-            return fake.mean()
+            if self.logistic:
+                return F.softplus(-fake).mean()
+            else:
+                return fake.mean()
         else:
             real_input.requires_grad_()  # <-- Needed to compute gradients on the input.
             real = D(real_input)
-            divergence_loss = (F.relu(1 + real) + F.relu(1 - fake)).mean()
+            if self.logistic:
+                rl = F.softplus(-real).mean()
+                fl = F.softplus(fake).mean()
+                return fl + rl
+            else:
+                divergence_loss = (F.relu(1 + real) + F.relu(1 - fake)).mean()
 
             # Apply gradient penalty. TODO: migrate this elsewhere.
             if self.env['step'] % self.gp_frequency == 0:
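For context, the 'logistic' path is the non-saturating logistic GAN loss used in the original StyleGAN2 training code: softplus(-logit) penalizes logits that should be large, softplus(logit) penalizes logits that should be small. A standalone sketch of both branches:

    import torch
    import torch.nn.functional as F

    real_logits = torch.randn(8, 1)
    fake_logits = torch.randn(8, 1)

    # Discriminator: push real logits up, fake logits down.
    d_loss = F.softplus(-real_logits).mean() + F.softplus(fake_logits).mean()
    # Generator (non-saturating): push fake logits up.
    g_loss = F.softplus(-fake_logits).mean()
    print(d_loss.item(), g_loss.item())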
@@ -11,6 +11,10 @@ from torch.autograd import Function
 
 # Ops -> The rosinality repo uses native cuda kernels for fused LeakyReLUs and upsamplers. This version extracts the
 # "cpu" alternative code and uses that instead for compatibility reasons.
+from trainer.networks import register_model
+from utils.util import opt_get
+
+
 class FusedLeakyReLU(nn.Module):
     def __init__(self, channel, bias=True, negative_slope=0.2, scale=2 ** 0.5):
         super().__init__()
@@ -609,11 +613,7 @@ class Generator(nn.Module):
 
         image = skip
 
-        if return_latents:
-            return image, latent
-
-        else:
-            return image, None
+        return image, latent
 
 
 class ConvLayer(nn.Sequential):
@@ -741,3 +741,14 @@ class Discriminator(nn.Module):
         out = self.final_linear(out)
 
         return out
+
+
+@register_model
+def register_stylegan2_rosinality_gen(opt_net, opt):
+    kw = opt_get(opt_net, ['kwargs'], {})
+    return Generator(**kw)
+
+
+@register_model
+def register_stylegan2_rosinality_disc(opt_net, opt):
+    kw = opt_get(opt_net, ['kwargs'], {})
+    return Discriminator(**kw)
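The two registration shims simply forward a 'kwargs' dict from the network options into the rosinality constructors. A hypothetical options fragment, for illustration only (the key names mirror the rosinality Generator/Discriminator signatures as used elsewhere in this commit, e.g. Generator(size, 512, 8, channel_multiplier=...); treat the exact values as assumptions):

    # Hypothetical opt_net blocks as they might appear after YAML parsing.
    gen_opt = {
        'type': 'stylegan2_rosinality_gen',
        'kwargs': {'size': 256, 'style_dim': 512, 'n_mlp': 8, 'channel_multiplier': 2},
    }
    disc_opt = {
        'type': 'stylegan2_rosinality_disc',
        'kwargs': {'size': 256, 'channel_multiplier': 2},
    }
    # The registered functions then do Generator(**gen_opt['kwargs']), etc.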
@@ -13,17 +13,17 @@ import torch
 def main():
     split_img = False
     opt = {}
-    opt['n_thread'] = 3
+    opt['n_thread'] = 5
     opt['compression_level'] = 95  # JPEG compression quality rating.
     # CV_IMWRITE_PNG_COMPRESSION from 0 to 9. A higher value means a smaller size and longer
     # compression time. If read raw images during training, use 0 for faster IO speed.
 
     opt['dest'] = 'file'
-    opt['input_folder'] = ['F:\\4k6k\\datasets\\ns_images\\imagesets\\pn_coven\\working']
-    opt['save_folder'] = 'F:\\4k6k\\datasets\\ns_images\\imagesets\\pn_coven\\cropped'
-    opt['imgsize'] = 1024
-    opt['bottom_crop'] = .1
-    opt['keep_folder'] = True
+    opt['input_folder'] = ['F:\\4k6k\\datasets\\images\\lsun\\lsun\\cats']
+    opt['save_folder'] = 'F:\\4k6k\\datasets\\images\\lsun\\lsun\\cats\\cropped'
+    opt['imgsize'] = 256
+    opt['bottom_crop'] = 0
+    opt['keep_folder'] = False
 
     save_folder = opt['save_folder']
     if not osp.exists(save_folder):
@@ -83,7 +83,7 @@ class TiledDataset(data.Dataset):
         pts = os.path.split(pts[0])
         output_folder = osp.join(self.opt['save_folder'], pts[-1])
         os.makedirs(output_folder, exist_ok=True)
-        cv2.imwrite(osp.join(output_folder, basename), img, [cv2.IMWRITE_JPEG_QUALITY, self.opt['compression_level']])
+        cv2.imwrite(osp.join(output_folder, basename.replace('.webp', '.jpg')), img, [cv2.IMWRITE_JPEG_QUALITY, self.opt['compression_level']])
         return None
 
     def __len__(self):
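The rename from .webp to .jpg matters because cv2.imwrite infers the output codec from the file extension; writing to a '.webp' name would produce a WebP file regardless of the JPEG quality flag. A minimal sketch:

    import cv2
    import numpy as np

    img = np.zeros((64, 64, 3), dtype=np.uint8)
    # The extension selects the encoder; the flags are interpreted per-codec.
    cv2.imwrite('out.jpg', img, [cv2.IMWRITE_JPEG_QUALITY, 95])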
@@ -1,292 +0,0 @@
-# Converts from Tensorflow Stylegan2 weights to weights used by this model.
-# Original source: https://raw.githubusercontent.com/rosinality/stylegan2-pytorch/master/convert_weight.py
-# Adapted to lucidrains' Stylegan implementation.
-#
-# Also doesn't require you to install Tensorflow 1.15 or clone the nVidia repo.
-
-# THIS DOES NOT CURRENTLY WORK.
-# It does transfer all weights from the stylegan model to the lucidrains one, but does not produce correct results.
-# The rosinality script this was stolen from has some "odd" intracacies that may be at cause for this: for example
-# weight "flipping" in the conv layers which I do not understand. It may also be because I botched some of the mods
-# required to make the lucidrains implementation conformant. I'll (maybe) get back to this some day.
-
-import argparse
-import os
-import sys
-import pickle
-import math
-
-import torch
-import numpy as np
-from torchvision import utils
-
-
-# Converts from the TF state_dict input provided into the vars originally expected from the rosinality converter.
-from models.stylegan.stylegan2_lucidrains import StyleGan2GeneratorWithLatent
-
-
-def get_vars(vars, source_name):
-    net_name = source_name.split('/')[0]
-    vars_as_tuple_list = vars[net_name]['variables']
-    result_vars = {}
-    for t in vars_as_tuple_list:
-        result_vars[t[0]] = t[1]
-    return result_vars, source_name.replace(net_name + "/", "")
-
-def get_vars_direct(vars, source_name):
-    v, n = get_vars(vars, source_name)
-    return v[n]
-
-
-def convert_modconv(vars, source_name, target_name, flip=False, numeral=1):
-    vars, source_name = get_vars(vars, source_name)
-    weight = vars[source_name + "/weight"]
-    mod_weight = vars[source_name + "/mod_weight"]
-    mod_bias = vars[source_name + "/mod_bias"]
-    noise = vars[source_name + "/noise_strength"]
-    bias = vars[source_name + "/bias"]
-
-    dic = {
-        f"conv{numeral}.weight": weight.transpose((3, 2, 0, 1)),
-        f"to_style{numeral}.weight": mod_weight.transpose((1, 0)),
-        f"to_style{numeral}.bias": mod_bias + 1,
-        f"noise{numeral}_scale": np.array([noise]),
-        f"activation{numeral}.bias": bias,
-    }
-
-    dic_torch = {}
-
-    for k, v in dic.items():
-        dic_torch[target_name + "." + k] = torch.from_numpy(v)
-
-    if flip:
-        dic_torch[target_name + f".conv{numeral}.weight"] = torch.flip(
-            dic_torch[target_name + f".conv{numeral}.weight"], [2, 3]
-        )
-
-    return dic_torch
-
-
-def convert_conv(vars, source_name, target_name, bias=True, start=0):
-    vars, source_name = get_vars(vars, source_name)
-    weight = vars[source_name + "/weight"]
-
-    dic = {"weight": weight.transpose((3, 2, 0, 1))}
-
-    if bias:
-        dic["bias"] = vars[source_name + "/bias"]
-
-    dic_torch = {}
-
-    dic_torch[target_name + f".{start}.weight"] = torch.from_numpy(dic["weight"])
-
-    if bias:
-        dic_torch[target_name + f".{start + 1}.bias"] = torch.from_numpy(dic["bias"])
-
-    return dic_torch
-
-
-def convert_torgb(vars, source_name, target_name):
-    vars, source_name = get_vars(vars, source_name)
-    weight = vars[source_name + "/weight"]
-    mod_weight = vars[source_name + "/mod_weight"]
-    mod_bias = vars[source_name + "/mod_bias"]
-    bias = vars[source_name + "/bias"]
-
-    dic = {
-        "conv.weight": weight.transpose((3, 2, 0, 1)),
-        "to_style.weight": mod_weight.transpose((1, 0)),
-        "to_style.bias": mod_bias + 1,
-        # "bias": bias.reshape((1, 3, 1, 1)), TODO: where is this?
-    }
-
-    dic_torch = {}
-
-    for k, v in dic.items():
-        dic_torch[target_name + "." + k] = torch.from_numpy(v)
-
-    return dic_torch
-
-
-def convert_dense(vars, source_name, target_name):
-    vars, source_name = get_vars(vars, source_name)
-    weight = vars[source_name + "/weight"]
-    bias = vars[source_name + "/bias"]
-
-    dic = {"weight": weight.transpose((1, 0)), "bias": bias}
-
-    dic_torch = {}
-
-    for k, v in dic.items():
-        dic_torch[target_name + "." + k] = torch.from_numpy(v)
-
-    return dic_torch
-
-
-def update(state_dict, new, strict=True):
-
-    for k, v in new.items():
-        if strict:
-            if k not in state_dict:
-                raise KeyError(k + " is not found")
-
-            if v.shape != state_dict[k].shape:
-                raise ValueError(f"Shape mismatch: {v.shape} vs {state_dict[k].shape}")
-
-        state_dict[k] = v
-
-
-def discriminator_fill_statedict(statedict, vars, size):
-    log_size = int(math.log(size, 2))
-
-    update(statedict, convert_conv(vars, f"{size}x{size}/FromRGB", "convs.0"))
-
-    conv_i = 1
-
-    for i in range(log_size - 2, 0, -1):
-        reso = 4 * 2 ** i
-        update(
-            statedict,
-            convert_conv(vars, f"{reso}x{reso}/Conv0", f"convs.{conv_i}.conv1"),
-        )
-        update(
-            statedict,
-            convert_conv(
-                vars, f"{reso}x{reso}/Conv1_down", f"convs.{conv_i}.conv2", start=1
-            ),
-        )
-        update(
-            statedict,
-            convert_conv(
-                vars, f"{reso}x{reso}/Skip", f"convs.{conv_i}.skip", start=1, bias=False
-            ),
-        )
-        conv_i += 1
-
-    update(statedict, convert_conv(vars, f"4x4/Conv", "final_conv"))
-    update(statedict, convert_dense(vars, f"4x4/Dense0", "final_linear.0"))
-    update(statedict, convert_dense(vars, f"Output", "final_linear.1"))
-
-    return statedict
-
-
-def fill_statedict(state_dict, vars, size):
-    log_size = int(math.log(size, 2))
-
-    for i in range(8):
-        update(state_dict, convert_dense(vars, f"G_mapping/Dense{i}", f"vectorizer.net.{i}"))
-
-    update(
-        state_dict,
-        {
-            "gen.initial_block": torch.from_numpy(
-                get_vars_direct(vars, "G_synthesis/4x4/Const/const")
-            )
-        },
-    )
-
-    for i in range(log_size - 1):
-        reso = 4 * 2 ** i
-        update(
-            state_dict,
-            convert_torgb(vars, f"G_synthesis/{reso}x{reso}/ToRGB", f"gen.blocks.{i}.to_rgb"),
-        )
-
-    update(state_dict, convert_modconv(vars, "G_synthesis/4x4/Conv", "gen.blocks.0", numeral=1))
-
-    for i in range(1, log_size - 1):
-        reso = 4 * 2 ** i
-        update(
-            state_dict,
-            convert_modconv(
-                vars,
-                f"G_synthesis/{reso}x{reso}/Conv0_up",
-                f"gen.blocks.{i}",
-                #flip=True, # TODO: why??
-                numeral=1
-            ),
-        )
-        update(
-            state_dict,
-            convert_modconv(
-                vars, f"G_synthesis/{reso}x{reso}/Conv1", f"gen.blocks.{i}", numeral=2
-            ),
-        )
-
-    '''
-    TODO: consider porting this, though I dont think it is necessary.
-    for i in range(0, (log_size - 2) * 2 + 1):
-        update(
-            state_dict,
-            {
-                f"noises.noise_{i}": torch.from_numpy(
-                    get_vars_direct(vars, f"G_synthesis/noise{i}")
-                )
-            },
-        )
-    '''
-
-    return state_dict
-
-
-if __name__ == "__main__":
-    device = "cuda"
-
-    parser = argparse.ArgumentParser(
-        description="Tensorflow to pytorch model checkpoint converter"
-    )
-    parser.add_argument(
-        "--gen", action="store_true", help="convert the generator weights"
-    )
-    parser.add_argument(
-        "--channel_multiplier",
-        type=int,
-        default=2,
-        help="channel multiplier factor. config-f = 2, else = 1",
-    )
-    parser.add_argument("path", metavar="PATH", help="path to the tensorflow weights")
-
-    args = parser.parse_args()
-    sys.path.append('scripts\\stylegan2')
-
-    import dnnlib
-    from dnnlib.tflib.network import generator, gen_ema
-
-    with open(args.path, "rb") as f:
-        pickle.load(f)
-
-    # Weight names are ordered by size. The last name will be something like '1024x1024/<blah>'. We just need to grab that first number.
-    size = int(generator['G_synthesis']['variables'][-1][0].split('x')[0])
-
-    g = StyleGan2GeneratorWithLatent(image_size=size, latent_dim=512, style_depth=8)
-    state_dict = g.state_dict()
-    state_dict = fill_statedict(state_dict, gen_ema, size)
-
-    g.load_state_dict(state_dict, strict=True)
-
-    latent_avg = torch.from_numpy(get_vars_direct(gen_ema, "G/dlatent_avg"))
-
-    ckpt = {"g_ema": state_dict, "latent_avg": latent_avg}
-
-    if args.gen:
-        g_train = Generator(size, 512, 8, channel_multiplier=args.channel_multiplier)
-        g_train_state = g_train.state_dict()
-        g_train_state = fill_statedict(g_train_state, generator, size)
-        ckpt["g"] = g_train_state
-
-    name = os.path.splitext(os.path.basename(args.path))[0]
-    torch.save(ckpt, name + ".pt")
-
-    batch_size = {256: 16, 512: 9, 1024: 4}
-    n_sample = batch_size.get(size, 25)
-
-    g = g.to(device)
-
-    z = np.random.RandomState(5).randn(n_sample, 512).astype("float32")
-
-    with torch.no_grad():
-        img_pt, _ = g(8)
-
-    utils.save_image(
-        img_pt, name + ".png", nrow=n_sample, normalize=True, range=(-1, 1)
-    )
@@ -123,7 +123,7 @@ def update(state_dict, new):
 def discriminator_fill_statedict(statedict, vars, size):
     log_size = int(math.log(size, 2))
 
-    update(statedict, convert_conv(vars, f"{size}x{size}/FromRGB", "convs.0"))
+    update(statedict, convert_conv(vars, f"D/{size}x{size}/FromRGB", "convs.0"))
 
     conv_i = 1
 
@@ -131,25 +131,25 @@ def discriminator_fill_statedict(statedict, vars, size):
         reso = 4 * 2 ** i
         update(
             statedict,
-            convert_conv(vars, f"{reso}x{reso}/Conv0", f"convs.{conv_i}.conv1"),
+            convert_conv(vars, f"D/{reso}x{reso}/Conv0", f"convs.{conv_i}.conv1"),
         )
         update(
             statedict,
             convert_conv(
-                vars, f"{reso}x{reso}/Conv1_down", f"convs.{conv_i}.conv2", start=1
+                vars, f"D/{reso}x{reso}/Conv1_down", f"convs.{conv_i}.conv2", start=1
             ),
         )
         update(
             statedict,
             convert_conv(
-                vars, f"{reso}x{reso}/Skip", f"convs.{conv_i}.skip", start=1, bias=False
+                vars, f"D/{reso}x{reso}/Skip", f"convs.{conv_i}.skip", start=1, bias=False
             ),
         )
         conv_i += 1
 
-    update(statedict, convert_conv(vars, f"4x4/Conv", "final_conv"))
-    update(statedict, convert_dense(vars, f"4x4/Dense0", "final_linear.0"))
-    update(statedict, convert_dense(vars, f"Output", "final_linear.1"))
+    update(statedict, convert_conv(vars, f"D/4x4/Conv", "final_conv"))
+    update(statedict, convert_dense(vars, f"D/4x4/Dense0", "final_linear.0"))
+    update(statedict, convert_dense(vars, f"D/Output", "final_linear.1"))
 
     return statedict
 
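The new 'D/' prefixes work because get_vars treats the first path segment as the network name, looks the variable list up under that key, and strips the prefix from the rest of the name. A toy illustration with a mocked-up TF weight dump (the helper mimics this script's get_vars; the data is fabricated for the example):

    def get_vars(vars, source_name):
        # First path segment selects the network ('D', 'G_synthesis', ...).
        net_name = source_name.split('/')[0]
        result_vars = {name: value for name, value in vars[net_name]['variables']}
        return result_vars, source_name.replace(net_name + "/", "")

    fake_dump = {'D': {'variables': [('256x256/FromRGB/weight', 'w0')]}}
    v, n = get_vars(fake_dump, 'D/256x256/FromRGB')
    print(n, v[n + '/weight'])  # -> 256x256/FromRGB w0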
@@ -249,9 +249,14 @@ if __name__ == "__main__":
     g = Generator(size, 512, 8, channel_multiplier=args.channel_multiplier)
     state_dict = g.state_dict()
     state_dict = fill_statedict(state_dict, gen_ema, size)
 
     g.load_state_dict(state_dict, strict=True)
 
+    d = Discriminator(size, args.channel_multiplier)
+    dstate_dict = d.state_dict()
+    dstate_dict = discriminator_fill_statedict(dstate_dict, discriminator, size)
+    d.load_state_dict(dstate_dict, strict=True)
+
     latent_avg = torch.from_numpy(get_vars_direct(gen_ema, "G/dlatent_avg"))
 
     ckpt = {"g_ema": state_dict, "latent_avg": latent_avg}
@@ -269,14 +274,16 @@ if __name__ == "__main__":
     ckpt["d"] = d_state
 
     name = os.path.splitext(os.path.basename(args.path))[0]
-    torch.save(state_dict, name + ".pth")
+    torch.save(state_dict, f"{name}_gen.pth")
+    torch.save(dstate_dict, f"{name}_disc.pth")
 
     batch_size = {256: 16, 512: 9, 1024: 4}
     n_sample = batch_size.get(size, 25)
 
     g = g.to(device)
+    d = d.to(device)
 
-    z = np.random.RandomState(5).randn(n_sample, 512).astype("float32")
+    z = np.random.RandomState(1).randn(n_sample, 512).astype("float32")
 
     with torch.no_grad():
         img_pt, _ = g(
@@ -285,6 +292,8 @@ if __name__ == "__main__":
             truncation_latent=latent_avg.to(device),
             randomize_noise=False,
         )
+    disc = d(img_pt)
+    print(disc)
 
     utils.save_image(
         img_pt, name + ".png", nrow=n_sample, normalize=True, range=(-1, 1)
@@ -295,7 +295,7 @@ class Trainer:
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_vqvae3_stage1.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_cats_stylegan2_rosinality.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()
@@ -13,6 +13,8 @@ from trainer.steps import ConfigurableStep
 from trainer.experiments.experiments import get_experiment_for_name
 import torchvision.utils as utils
 
+from utils.util import opt_get
+
 logger = logging.getLogger('base')
 
 
@@ -253,6 +255,8 @@ class ExtensibleTrainer(BaseModel):
 
         # Record visual outputs for usage in debugging and testing.
         if 'visuals' in self.opt['logger'].keys() and self.rank <= 0 and step % self.opt['logger']['visual_debug_rate'] == 0:
+            denorm = opt_get(self.opt, ['logger', 'denormalize'], False)
+            denorm_range = tuple(opt_get(self.opt, ['logger', 'denormalize_range'], None))
             sample_save_path = os.path.join(self.opt['path']['models'], "..", "visual_dbg")
             for v in self.opt['logger']['visuals']:
                 if v not in state.keys():
@@ -264,12 +268,12 @@ class ExtensibleTrainer(BaseModel):
                     if rdbgv.shape[1] > 3:
                         rdbgv = rdbgv[:, :3, :, :]
                     os.makedirs(os.path.join(sample_save_path, v), exist_ok=True)
-                    utils.save_image(rdbgv.float(), os.path.join(sample_save_path, v, "%05i_%02i_%02i.png" % (step, rvi, i)))
+                    utils.save_image(rdbgv.float(), os.path.join(sample_save_path, v, "%05i_%02i_%02i.png" % (step, rvi, i)), normalize=denorm, range=denorm_range)
                 else:
                     if dbgv.shape[1] > 3:
                         dbgv = dbgv[:,:3,:,:]
                     os.makedirs(os.path.join(sample_save_path, v), exist_ok=True)
-                    utils.save_image(dbgv.float(), os.path.join(sample_save_path, v, "%05i_%02i.png" % (step, i)))
+                    utils.save_image(dbgv.float(), os.path.join(sample_save_path, v, "%05i_%02i.png" % (step, i)), normalize=denorm, range=denorm_range)
             # Some models have their own specific visual debug routines.
             for net_name, net in self.networks.items():
                 if hasattr(net.module, "visual_dbg"):
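The denormalize path leans on torchvision's save_image: with normalize=True and a range of (-1, 1), the tensor is remapped back into [0, 1] before being written out, undoing the stylegan2_norm applied at load time. A small sketch of the logger options this code reads (note that newer torchvision versions renamed the range argument to value_range):

    import torch
    import torchvision.utils as utils

    opt = {'logger': {'denormalize': True, 'denormalize_range': [-1, 1]}}
    img = torch.rand(4, 3, 32, 32) * 2 - 1   # pretend model output in [-1, 1]
    utils.save_image(img, 'dbg.png', normalize=True,
                     range=tuple(opt['logger']['denormalize_range']))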
@@ -8,6 +8,9 @@ from pytorch_fid import fid_score
 
 
 # Evaluate that generates uniform noise to feed into a generator, then calculates a FID score on the results.
+from utils.util import opt_get
+
+
 class StyleTransferEvaluator(evaluator.Evaluator):
     def __init__(self, model, opt_eval, env):
         super().__init__(model, opt_eval, env)
@@ -16,13 +19,18 @@ class StyleTransferEvaluator(evaluator.Evaluator):
         self.im_sz = opt_eval['image_size']
         self.fid_real_samples = opt_eval['real_fid_path']
         self.gen_output_index = opt_eval['gen_index'] if 'gen_index' in opt_eval.keys() else 0
+        self.noise_type = opt_get(opt_eval, ['noise_type'], 'imgnoise')
+        self.latent_dim = opt_get(opt_eval, ['latent_dim'], 512)  # Not needed if using 'imgnoise' input.
 
     def perform_eval(self):
         fid_fake_path = osp.join(self.env['base_path'], "../../models", "fid", str(self.env["step"]))
         os.makedirs(fid_fake_path, exist_ok=True)
         counter = 0
         for i in range(self.batches_per_eval):
-            batch = torch.FloatTensor(self.batch_sz, 3, self.im_sz, self.im_sz).uniform_(0., 1.).to(self.env['device'])
+            if self.noise_type == 'imgnoise':
+                batch = torch.FloatTensor(self.batch_sz, 3, self.im_sz, self.im_sz).uniform_(0., 1.).to(self.env['device'])
+            elif self.noise_type == 'stylenoise':
+                batch = [torch.randn(self.batch_sz, self.latent_dim).to(self.env['device'])]
             gen = self.model(batch)
             if not isinstance(gen, list) and not isinstance(gen, tuple):
                 gen = [gen]
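The two noise types reflect the two generator families this evaluator now supports: models that consume an image-shaped uniform-noise tensor ('imgnoise'), and rosinality-style models whose forward takes a list of latent style vectors ('stylenoise'). A shape-level sketch, using the defaults read above:

    import torch

    batch_sz, im_sz, latent_dim = 8, 256, 512

    imgnoise = torch.FloatTensor(batch_sz, 3, im_sz, im_sz).uniform_(0., 1.)
    stylenoise = [torch.randn(batch_sz, latent_dim)]  # list-of-styles input

    print(imgnoise.shape, stylenoise[0].shape)  # (8, 3, 256, 256) and (8, 512)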
@@ -6,6 +6,7 @@ from torch.cuda.amp import autocast
 
 from trainer.inject import Injector
 from trainer.losses import extract_params_from_state
+from utils.util import opt_get
 from utils.weight_scheduler import get_scheduler_for_opt
 
 
@@ -386,3 +387,19 @@ class RandomCropInjector(Injector):
     def forward(self, state):
         return {self.output: self.operator(state[self.input])}
 
+
+class Stylegan2NoiseInjector(Injector):
+    def __init__(self, opt, env):
+        super().__init__(opt, env)
+        self.mix_prob = opt_get(opt, ['mix_probability'], .9)
+        self.latent_dim = opt_get(opt, ['latent_dim'], 512)
+
+    def make_noise(self, batch, latent_dim, n_noise, device):
+        return torch.randn(n_noise, batch, latent_dim, device=device).unbind(0)
+
+    def forward(self, state):
+        i = state[self.input]
+        if self.mix_prob > 0 and random.random() < self.mix_prob:
+            return {self.output: self.make_noise(i.shape[0], self.latent_dim, 2, i.device)}
+        else:
+            return {self.output: self.make_noise(i.shape[0], self.latent_dim, 1, i.device)}
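This is the "separated paths" injector from the commit message: with probability mix_probability it emits two independent latents (style mixing, as in StyleGAN2 training), otherwise a single latent. A standalone sketch of the sampling logic:

    import random
    import torch

    def make_noise(batch, latent_dim, n_noise, device):
        # unbind(0) turns the (n_noise, batch, latent_dim) tensor into a tuple of n_noise latents.
        return torch.randn(n_noise, batch, latent_dim, device=device).unbind(0)

    mix_prob, latent_dim, batch = .9, 512, 4
    device = torch.device('cpu')
    n = 2 if random.random() < mix_prob else 1
    latents = make_noise(batch, latent_dim, n, device)
    print(len(latents), latents[0].shape)  # usually 2, torch.Size([4, 512])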