Misc options to add support for training stylegan2-rosinality models:

- Allow image_folder_dataset to normalize inbound images
- ExtensibleTrainer can denormalize images on the output path
- Support .webp - an output from LSUN
- Support logistic GAN divergence loss
- Support stylegan2 TF weight extraction for discriminator
- New injector that produces latent noise (with separated paths)
- Modify FID evaluator to be operable with rosinality-style GANs
This commit is contained in:
James Betker 2021-02-08 08:09:21 -07:00
parent e7be4bdff3
commit 784b96c059
11 changed files with 100 additions and 324 deletions

View File

@ -8,6 +8,8 @@ import numpy as np
import torch
import os
from torchvision.transforms import Normalize
from data import util
# Builds a dataset created from a simple folder containing a list of training/test/validation images.
from data.image_corruptor import ImageCorruptor
@ -28,6 +30,13 @@ class ImageFolderDataset:
# from the same video source. Search for 'fetch_alt_image' for more info.
self.skip_lq = opt_get(opt, ['skip_lq'], False)
self.disable_flip = opt_get(opt, ['disable_flip'], False)
if 'normalize' in opt.keys():
if opt['normalize'] == 'stylegan2_norm':
self.normalize = Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
else:
raise Exception('Unsupported normalize')
else:
self.normalize = None
assert (self.target_hq_size // self.scale) % self.multiple == 0 # If we dont throw here, we get some really obscure errors.
if not isinstance(self.paths, list):
self.paths = [self.paths]
@ -128,6 +137,8 @@ class ImageFolderDataset:
# Convert to torch tensor
hq = torch.from_numpy(np.ascontiguousarray(np.transpose(hs[0], (2, 0, 1)))).float()
if self.normalize:
hq = self.normalize(hq)
out_dict = {'hq': hq, 'LQ_path': self.image_paths[item], 'HQ_path': self.image_paths[item]}

View File

@ -70,7 +70,7 @@ class expand_greyscale(object):
class Stylegan2Dataset(data.Dataset):
def __init__(self, opt):
super().__init__()
EXTS = ['jpg', 'jpeg', 'png']
EXTS = ['jpg', 'jpeg', 'png', 'webp']
self.folder = opt['path']
self.image_size = opt['target_size']
self.paths = [p for ext in EXTS for p in Path(f'{self.folder}').glob(f'**/*.{ext}')]

View File

@ -248,13 +248,12 @@ def gradient_penalty(images, output, weight=10, return_structured_grads=False):
return penalty
def calc_pl_lengths(styles, images):
device = images.device
num_pixels = images.shape[2] * images.shape[3]
pl_noise = torch.randn(images.shape, device=device) / math.sqrt(num_pixels)
pl_noise = torch.randn_like(images) / math.sqrt(num_pixels)
outputs = (images * pl_noise).sum()
pl_grads = torch_grad(outputs=outputs, inputs=styles,
grad_outputs=torch.ones(outputs.shape, device=device),
grad_outputs=torch.ones_like(outputs),
create_graph=True, retain_graph=True, only_inputs=True)[0]
return (pl_grads ** 2).sum(dim=2).mean(dim=1).sqrt()
@ -850,6 +849,7 @@ class StyleGan2DivergenceLoss(L.ConfigurableLoss):
self.for_gen = opt['gen_loss']
self.gp_frequency = opt['gradient_penalty_frequency']
self.noise = opt['noise'] if 'noise' in opt.keys() else 0
self.logistic = opt_get(opt, ['logistic'], False) # Applies a logistic curve to the output logits, which is what the StyleGAN2 authors used.
def forward(self, net, state):
real_input = state[self.real]
@ -861,10 +861,18 @@ class StyleGan2DivergenceLoss(L.ConfigurableLoss):
D = self.env['discriminators'][self.discriminator]
fake = D(fake_input)
if self.for_gen:
if self.logistic:
return F.softplus(-fake).mean()
else:
return fake.mean()
else:
real_input.requires_grad_() # <-- Needed to compute gradients on the input.
real = D(real_input)
if self.logistic:
rl = F.softplus(-real).mean()
fl = F.softplus(fake).mean()
return fl + rl
else:
divergence_loss = (F.relu(1 + real) + F.relu(1 - fake)).mean()
# Apply gradient penalty. TODO: migrate this elsewhere.

View File

@ -11,6 +11,10 @@ from torch.autograd import Function
# Ops -> The rosinality repo uses native cuda kernels for fused LeakyReLUs and upsamplers. This version extracts the
# "cpu" alternative code and uses that instead for compatibility reasons.
from trainer.networks import register_model
from utils.util import opt_get
class FusedLeakyReLU(nn.Module):
def __init__(self, channel, bias=True, negative_slope=0.2, scale=2 ** 0.5):
super().__init__()
@ -609,12 +613,8 @@ class Generator(nn.Module):
image = skip
if return_latents:
return image, latent
else:
return image, None
class ConvLayer(nn.Sequential):
def __init__(
@ -741,3 +741,14 @@ class Discriminator(nn.Module):
out = self.final_linear(out)
return out
@register_model
def register_stylegan2_rosinality_gen(opt_net, opt):
kw = opt_get(opt_net, ['kwargs'], {})
return Generator(**kw)
@register_model
def register_stylegan2_rosinality_disc(opt_net, opt):
kw = opt_get(opt_net, ['kwargs'], {})
return Discriminator(**kw)

View File

@ -13,17 +13,17 @@ import torch
def main():
split_img = False
opt = {}
opt['n_thread'] = 3
opt['n_thread'] = 5
opt['compression_level'] = 95 # JPEG compression quality rating.
# CV_IMWRITE_PNG_COMPRESSION from 0 to 9. A higher value means a smaller size and longer
# compression time. If read raw images during training, use 0 for faster IO speed.
opt['dest'] = 'file'
opt['input_folder'] = ['F:\\4k6k\\datasets\\ns_images\\imagesets\\pn_coven\\working']
opt['save_folder'] = 'F:\\4k6k\\datasets\\ns_images\\imagesets\\pn_coven\\cropped'
opt['imgsize'] = 1024
opt['bottom_crop'] = .1
opt['keep_folder'] = True
opt['input_folder'] = ['F:\\4k6k\\datasets\\images\\lsun\\lsun\\cats']
opt['save_folder'] = 'F:\\4k6k\\datasets\\images\\lsun\\lsun\\cats\\cropped'
opt['imgsize'] = 256
opt['bottom_crop'] = 0
opt['keep_folder'] = False
save_folder = opt['save_folder']
if not osp.exists(save_folder):
@ -83,7 +83,7 @@ class TiledDataset(data.Dataset):
pts = os.path.split(pts[0])
output_folder = osp.join(self.opt['save_folder'], pts[-1])
os.makedirs(output_folder, exist_ok=True)
cv2.imwrite(osp.join(output_folder, basename), img, [cv2.IMWRITE_JPEG_QUALITY, self.opt['compression_level']])
cv2.imwrite(osp.join(output_folder, basename.replace('.webp', '.jpg')), img, [cv2.IMWRITE_JPEG_QUALITY, self.opt['compression_level']])
return None
def __len__(self):

View File

@ -1,292 +0,0 @@
# Converts from Tensorflow Stylegan2 weights to weights used by this model.
# Original source: https://raw.githubusercontent.com/rosinality/stylegan2-pytorch/master/convert_weight.py
# Adapted to lucidrains' Stylegan implementation.
#
# Also doesn't require you to install Tensorflow 1.15 or clone the nVidia repo.
# THIS DOES NOT CURRENTLY WORK.
# It does transfer all weights from the stylegan model to the lucidrains one, but does not produce correct results.
# The rosinality script this was stolen from has some "odd" intracacies that may be at cause for this: for example
# weight "flipping" in the conv layers which I do not understand. It may also be because I botched some of the mods
# required to make the lucidrains implementation conformant. I'll (maybe) get back to this some day.
import argparse
import os
import sys
import pickle
import math
import torch
import numpy as np
from torchvision import utils
# Converts from the TF state_dict input provided into the vars originally expected from the rosinality converter.
from models.stylegan.stylegan2_lucidrains import StyleGan2GeneratorWithLatent
def get_vars(vars, source_name):
net_name = source_name.split('/')[0]
vars_as_tuple_list = vars[net_name]['variables']
result_vars = {}
for t in vars_as_tuple_list:
result_vars[t[0]] = t[1]
return result_vars, source_name.replace(net_name + "/", "")
def get_vars_direct(vars, source_name):
v, n = get_vars(vars, source_name)
return v[n]
def convert_modconv(vars, source_name, target_name, flip=False, numeral=1):
vars, source_name = get_vars(vars, source_name)
weight = vars[source_name + "/weight"]
mod_weight = vars[source_name + "/mod_weight"]
mod_bias = vars[source_name + "/mod_bias"]
noise = vars[source_name + "/noise_strength"]
bias = vars[source_name + "/bias"]
dic = {
f"conv{numeral}.weight": weight.transpose((3, 2, 0, 1)),
f"to_style{numeral}.weight": mod_weight.transpose((1, 0)),
f"to_style{numeral}.bias": mod_bias + 1,
f"noise{numeral}_scale": np.array([noise]),
f"activation{numeral}.bias": bias,
}
dic_torch = {}
for k, v in dic.items():
dic_torch[target_name + "." + k] = torch.from_numpy(v)
if flip:
dic_torch[target_name + f".conv{numeral}.weight"] = torch.flip(
dic_torch[target_name + f".conv{numeral}.weight"], [2, 3]
)
return dic_torch
def convert_conv(vars, source_name, target_name, bias=True, start=0):
vars, source_name = get_vars(vars, source_name)
weight = vars[source_name + "/weight"]
dic = {"weight": weight.transpose((3, 2, 0, 1))}
if bias:
dic["bias"] = vars[source_name + "/bias"]
dic_torch = {}
dic_torch[target_name + f".{start}.weight"] = torch.from_numpy(dic["weight"])
if bias:
dic_torch[target_name + f".{start + 1}.bias"] = torch.from_numpy(dic["bias"])
return dic_torch
def convert_torgb(vars, source_name, target_name):
vars, source_name = get_vars(vars, source_name)
weight = vars[source_name + "/weight"]
mod_weight = vars[source_name + "/mod_weight"]
mod_bias = vars[source_name + "/mod_bias"]
bias = vars[source_name + "/bias"]
dic = {
"conv.weight": weight.transpose((3, 2, 0, 1)),
"to_style.weight": mod_weight.transpose((1, 0)),
"to_style.bias": mod_bias + 1,
# "bias": bias.reshape((1, 3, 1, 1)), TODO: where is this?
}
dic_torch = {}
for k, v in dic.items():
dic_torch[target_name + "." + k] = torch.from_numpy(v)
return dic_torch
def convert_dense(vars, source_name, target_name):
vars, source_name = get_vars(vars, source_name)
weight = vars[source_name + "/weight"]
bias = vars[source_name + "/bias"]
dic = {"weight": weight.transpose((1, 0)), "bias": bias}
dic_torch = {}
for k, v in dic.items():
dic_torch[target_name + "." + k] = torch.from_numpy(v)
return dic_torch
def update(state_dict, new, strict=True):
for k, v in new.items():
if strict:
if k not in state_dict:
raise KeyError(k + " is not found")
if v.shape != state_dict[k].shape:
raise ValueError(f"Shape mismatch: {v.shape} vs {state_dict[k].shape}")
state_dict[k] = v
def discriminator_fill_statedict(statedict, vars, size):
log_size = int(math.log(size, 2))
update(statedict, convert_conv(vars, f"{size}x{size}/FromRGB", "convs.0"))
conv_i = 1
for i in range(log_size - 2, 0, -1):
reso = 4 * 2 ** i
update(
statedict,
convert_conv(vars, f"{reso}x{reso}/Conv0", f"convs.{conv_i}.conv1"),
)
update(
statedict,
convert_conv(
vars, f"{reso}x{reso}/Conv1_down", f"convs.{conv_i}.conv2", start=1
),
)
update(
statedict,
convert_conv(
vars, f"{reso}x{reso}/Skip", f"convs.{conv_i}.skip", start=1, bias=False
),
)
conv_i += 1
update(statedict, convert_conv(vars, f"4x4/Conv", "final_conv"))
update(statedict, convert_dense(vars, f"4x4/Dense0", "final_linear.0"))
update(statedict, convert_dense(vars, f"Output", "final_linear.1"))
return statedict
def fill_statedict(state_dict, vars, size):
log_size = int(math.log(size, 2))
for i in range(8):
update(state_dict, convert_dense(vars, f"G_mapping/Dense{i}", f"vectorizer.net.{i}"))
update(
state_dict,
{
"gen.initial_block": torch.from_numpy(
get_vars_direct(vars, "G_synthesis/4x4/Const/const")
)
},
)
for i in range(log_size - 1):
reso = 4 * 2 ** i
update(
state_dict,
convert_torgb(vars, f"G_synthesis/{reso}x{reso}/ToRGB", f"gen.blocks.{i}.to_rgb"),
)
update(state_dict, convert_modconv(vars, "G_synthesis/4x4/Conv", "gen.blocks.0", numeral=1))
for i in range(1, log_size - 1):
reso = 4 * 2 ** i
update(
state_dict,
convert_modconv(
vars,
f"G_synthesis/{reso}x{reso}/Conv0_up",
f"gen.blocks.{i}",
#flip=True, # TODO: why??
numeral=1
),
)
update(
state_dict,
convert_modconv(
vars, f"G_synthesis/{reso}x{reso}/Conv1", f"gen.blocks.{i}", numeral=2
),
)
'''
TODO: consider porting this, though I dont think it is necessary.
for i in range(0, (log_size - 2) * 2 + 1):
update(
state_dict,
{
f"noises.noise_{i}": torch.from_numpy(
get_vars_direct(vars, f"G_synthesis/noise{i}")
)
},
)
'''
return state_dict
if __name__ == "__main__":
device = "cuda"
parser = argparse.ArgumentParser(
description="Tensorflow to pytorch model checkpoint converter"
)
parser.add_argument(
"--gen", action="store_true", help="convert the generator weights"
)
parser.add_argument(
"--channel_multiplier",
type=int,
default=2,
help="channel multiplier factor. config-f = 2, else = 1",
)
parser.add_argument("path", metavar="PATH", help="path to the tensorflow weights")
args = parser.parse_args()
sys.path.append('scripts\\stylegan2')
import dnnlib
from dnnlib.tflib.network import generator, gen_ema
with open(args.path, "rb") as f:
pickle.load(f)
# Weight names are ordered by size. The last name will be something like '1024x1024/<blah>'. We just need to grab that first number.
size = int(generator['G_synthesis']['variables'][-1][0].split('x')[0])
g = StyleGan2GeneratorWithLatent(image_size=size, latent_dim=512, style_depth=8)
state_dict = g.state_dict()
state_dict = fill_statedict(state_dict, gen_ema, size)
g.load_state_dict(state_dict, strict=True)
latent_avg = torch.from_numpy(get_vars_direct(gen_ema, "G/dlatent_avg"))
ckpt = {"g_ema": state_dict, "latent_avg": latent_avg}
if args.gen:
g_train = Generator(size, 512, 8, channel_multiplier=args.channel_multiplier)
g_train_state = g_train.state_dict()
g_train_state = fill_statedict(g_train_state, generator, size)
ckpt["g"] = g_train_state
name = os.path.splitext(os.path.basename(args.path))[0]
torch.save(ckpt, name + ".pt")
batch_size = {256: 16, 512: 9, 1024: 4}
n_sample = batch_size.get(size, 25)
g = g.to(device)
z = np.random.RandomState(5).randn(n_sample, 512).astype("float32")
with torch.no_grad():
img_pt, _ = g(8)
utils.save_image(
img_pt, name + ".png", nrow=n_sample, normalize=True, range=(-1, 1)
)

View File

@ -123,7 +123,7 @@ def update(state_dict, new):
def discriminator_fill_statedict(statedict, vars, size):
log_size = int(math.log(size, 2))
update(statedict, convert_conv(vars, f"{size}x{size}/FromRGB", "convs.0"))
update(statedict, convert_conv(vars, f"D/{size}x{size}/FromRGB", "convs.0"))
conv_i = 1
@ -131,25 +131,25 @@ def discriminator_fill_statedict(statedict, vars, size):
reso = 4 * 2 ** i
update(
statedict,
convert_conv(vars, f"{reso}x{reso}/Conv0", f"convs.{conv_i}.conv1"),
convert_conv(vars, f"D/{reso}x{reso}/Conv0", f"convs.{conv_i}.conv1"),
)
update(
statedict,
convert_conv(
vars, f"{reso}x{reso}/Conv1_down", f"convs.{conv_i}.conv2", start=1
vars, f"D/{reso}x{reso}/Conv1_down", f"convs.{conv_i}.conv2", start=1
),
)
update(
statedict,
convert_conv(
vars, f"{reso}x{reso}/Skip", f"convs.{conv_i}.skip", start=1, bias=False
vars, f"D/{reso}x{reso}/Skip", f"convs.{conv_i}.skip", start=1, bias=False
),
)
conv_i += 1
update(statedict, convert_conv(vars, f"4x4/Conv", "final_conv"))
update(statedict, convert_dense(vars, f"4x4/Dense0", "final_linear.0"))
update(statedict, convert_dense(vars, f"Output", "final_linear.1"))
update(statedict, convert_conv(vars, f"D/4x4/Conv", "final_conv"))
update(statedict, convert_dense(vars, f"D/4x4/Dense0", "final_linear.0"))
update(statedict, convert_dense(vars, f"D/Output", "final_linear.1"))
return statedict
@ -249,9 +249,14 @@ if __name__ == "__main__":
g = Generator(size, 512, 8, channel_multiplier=args.channel_multiplier)
state_dict = g.state_dict()
state_dict = fill_statedict(state_dict, gen_ema, size)
g.load_state_dict(state_dict, strict=True)
d = Discriminator(size, args.channel_multiplier)
dstate_dict = d.state_dict()
dstate_dict = discriminator_fill_statedict(dstate_dict, discriminator, size)
d.load_state_dict(dstate_dict, strict=True)
latent_avg = torch.from_numpy(get_vars_direct(gen_ema, "G/dlatent_avg"))
ckpt = {"g_ema": state_dict, "latent_avg": latent_avg}
@ -269,14 +274,16 @@ if __name__ == "__main__":
ckpt["d"] = d_state
name = os.path.splitext(os.path.basename(args.path))[0]
torch.save(state_dict, name + ".pth")
torch.save(state_dict, f"{name}_gen.pth")
torch.save(dstate_dict, f"{name}_disc.pth")
batch_size = {256: 16, 512: 9, 1024: 4}
n_sample = batch_size.get(size, 25)
g = g.to(device)
d = d.to(device)
z = np.random.RandomState(5).randn(n_sample, 512).astype("float32")
z = np.random.RandomState(1).randn(n_sample, 512).astype("float32")
with torch.no_grad():
img_pt, _ = g(
@ -285,6 +292,8 @@ if __name__ == "__main__":
truncation_latent=latent_avg.to(device),
randomize_noise=False,
)
disc = d(img_pt)
print(disc)
utils.save_image(
img_pt, name + ".png", nrow=n_sample, normalize=True, range=(-1, 1)

View File

@ -295,7 +295,7 @@ class Trainer:
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_vqvae3_stage1.yml')
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_cats_stylegan2_rosinality.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args()

View File

@ -13,6 +13,8 @@ from trainer.steps import ConfigurableStep
from trainer.experiments.experiments import get_experiment_for_name
import torchvision.utils as utils
from utils.util import opt_get
logger = logging.getLogger('base')
@ -253,6 +255,8 @@ class ExtensibleTrainer(BaseModel):
# Record visual outputs for usage in debugging and testing.
if 'visuals' in self.opt['logger'].keys() and self.rank <= 0 and step % self.opt['logger']['visual_debug_rate'] == 0:
denorm = opt_get(self.opt, ['logger', 'denormalize'], False)
denorm_range = tuple(opt_get(self.opt, ['logger', 'denormalize_range'], None))
sample_save_path = os.path.join(self.opt['path']['models'], "..", "visual_dbg")
for v in self.opt['logger']['visuals']:
if v not in state.keys():
@ -264,12 +268,12 @@ class ExtensibleTrainer(BaseModel):
if rdbgv.shape[1] > 3:
rdbgv = rdbgv[:, :3, :, :]
os.makedirs(os.path.join(sample_save_path, v), exist_ok=True)
utils.save_image(rdbgv.float(), os.path.join(sample_save_path, v, "%05i_%02i_%02i.png" % (step, rvi, i)))
utils.save_image(rdbgv.float(), os.path.join(sample_save_path, v, "%05i_%02i_%02i.png" % (step, rvi, i)), normalize=denorm, range=denorm_range)
else:
if dbgv.shape[1] > 3:
dbgv = dbgv[:,:3,:,:]
os.makedirs(os.path.join(sample_save_path, v), exist_ok=True)
utils.save_image(dbgv.float(), os.path.join(sample_save_path, v, "%05i_%02i.png" % (step, i)))
utils.save_image(dbgv.float(), os.path.join(sample_save_path, v, "%05i_%02i.png" % (step, i)), normalize=denorm, range=denorm_range)
# Some models have their own specific visual debug routines.
for net_name, net in self.networks.items():
if hasattr(net.module, "visual_dbg"):

View File

@ -8,6 +8,9 @@ from pytorch_fid import fid_score
# Evaluate that generates uniform noise to feed into a generator, then calculates a FID score on the results.
from utils.util import opt_get
class StyleTransferEvaluator(evaluator.Evaluator):
def __init__(self, model, opt_eval, env):
super().__init__(model, opt_eval, env)
@ -16,13 +19,18 @@ class StyleTransferEvaluator(evaluator.Evaluator):
self.im_sz = opt_eval['image_size']
self.fid_real_samples = opt_eval['real_fid_path']
self.gen_output_index = opt_eval['gen_index'] if 'gen_index' in opt_eval.keys() else 0
self.noise_type = opt_get(opt_eval, ['noise_type'], 'imgnoise')
self.latent_dim = opt_get(opt_eval, ['latent_dim'], 512) # Not needed if using 'imgnoise' input.
def perform_eval(self):
fid_fake_path = osp.join(self.env['base_path'], "../../models", "fid", str(self.env["step"]))
os.makedirs(fid_fake_path, exist_ok=True)
counter = 0
for i in range(self.batches_per_eval):
if self.noise_type == 'imgnoise':
batch = torch.FloatTensor(self.batch_sz, 3, self.im_sz, self.im_sz).uniform_(0., 1.).to(self.env['device'])
elif self.noise_type == 'stylenoise':
batch = [torch.randn(self.batch_sz, self.latent_dim).to(self.env['device'])]
gen = self.model(batch)
if not isinstance(gen, list) and not isinstance(gen, tuple):
gen = [gen]

View File

@ -6,6 +6,7 @@ from torch.cuda.amp import autocast
from trainer.inject import Injector
from trainer.losses import extract_params_from_state
from utils.util import opt_get
from utils.weight_scheduler import get_scheduler_for_opt
@ -386,3 +387,19 @@ class RandomCropInjector(Injector):
def forward(self, state):
return {self.output: self.operator(state[self.input])}
class Stylegan2NoiseInjector(Injector):
def __init__(self, opt, env):
super().__init__(opt, env)
self.mix_prob = opt_get(opt, ['mix_probability'], .9)
self.latent_dim = opt_get(opt, ['latent_dim'], 512)
def make_noise(self, batch, latent_dim, n_noise, device):
return torch.randn(n_noise, batch, latent_dim, device=device).unbind(0)
def forward(self, state):
i = state[self.input]
if self.mix_prob > 0 and random.random() < self.mix_prob:
return {self.output: self.make_noise(i.shape[0], self.latent_dim, 2, i.device)}
else:
return {self.output: self.make_noise(i.shape[0], self.latent_dim, 1, i.device)}