Misc options to add support for training stylegan2-rosinality models:
- Allow image_folder_dataset to normalize inbound images
- ExtensibleTrainer can denormalize images on the output path
- Support .webp (an output from LSUN)
- Support logistic GAN divergence loss
- Support stylegan2 TF weight extraction for discriminator
- New injector that produces latent noise (with separated paths)
- Modify FID evaluator to be operable with rosinality-style GANs
commit 784b96c059 (parent e7be4bdff3)
@@ -8,6 +8,8 @@ import numpy as np
 import torch
 import os
 
+from torchvision.transforms import Normalize
+
 from data import util
 # Builds a dataset created from a simple folder containing a list of training/test/validation images.
 from data.image_corruptor import ImageCorruptor
@@ -28,6 +30,13 @@ class ImageFolderDataset:
         # from the same video source. Search for 'fetch_alt_image' for more info.
         self.skip_lq = opt_get(opt, ['skip_lq'], False)
         self.disable_flip = opt_get(opt, ['disable_flip'], False)
+        if 'normalize' in opt.keys():
+            if opt['normalize'] == 'stylegan2_norm':
+                self.normalize = Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
+            else:
+                raise Exception('Unsupported normalize')
+        else:
+            self.normalize = None
         assert (self.target_hq_size // self.scale) % self.multiple == 0  # If we don't throw here, we get some really obscure errors.
         if not isinstance(self.paths, list):
             self.paths = [self.paths]
@@ -128,6 +137,8 @@ class ImageFolderDataset:
 
         # Convert to torch tensor
         hq = torch.from_numpy(np.ascontiguousarray(np.transpose(hs[0], (2, 0, 1)))).float()
+        if self.normalize:
+            hq = self.normalize(hq)
 
         out_dict = {'hq': hq, 'LQ_path': self.image_paths[item], 'HQ_path': self.image_paths[item]}
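Taken together, the hunks above make normalization opt-in per dataset: `stylegan2_norm` remaps [0, 1] image tensors to the [-1, 1] range that rosinality-style models expect. A minimal sketch of dataset options that would exercise the new path (values are illustrative; `normalize` is the only new key):

```python
# Hypothetical dataset options; other keys follow existing
# image_folder_dataset conventions.
dataset_opt = {
    'paths': '/data/lsun/cats/cropped',  # illustrative dataset root
    'scale': 1,
    'normalize': 'stylegan2_norm',       # (x - 0.5) / 0.5, i.e. [0, 1] -> [-1, 1]
}
```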
@@ -70,7 +70,7 @@ class expand_greyscale(object):
 class Stylegan2Dataset(data.Dataset):
     def __init__(self, opt):
         super().__init__()
-        EXTS = ['jpg', 'jpeg', 'png']
+        EXTS = ['jpg', 'jpeg', 'png', 'webp']
         self.folder = opt['path']
         self.image_size = opt['target_size']
         self.paths = [p for ext in EXTS for p in Path(f'{self.folder}').glob(f'**/*.{ext}')]
@@ -248,13 +248,12 @@ def gradient_penalty(images, output, weight=10, return_structured_grads=False):
     return penalty
 
 def calc_pl_lengths(styles, images):
-    device = images.device
     num_pixels = images.shape[2] * images.shape[3]
-    pl_noise = torch.randn(images.shape, device=device) / math.sqrt(num_pixels)
+    pl_noise = torch.randn_like(images) / math.sqrt(num_pixels)
     outputs = (images * pl_noise).sum()
 
     pl_grads = torch_grad(outputs=outputs, inputs=styles,
-                          grad_outputs=torch.ones(outputs.shape, device=device),
+                          grad_outputs=torch.ones_like(outputs),
                           create_graph=True, retain_graph=True, only_inputs=True)[0]
 
     return (pl_grads ** 2).sum(dim=2).mean(dim=1).sqrt()
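Both substitutions in this hunk are behavior-preserving cleanups: `torch.randn_like(t)` matches `torch.randn(t.shape, device=t.device, dtype=t.dtype)`, and `torch.ones_like` likewise, so the explicit `device` bookkeeping can be dropped. A quick sanity sketch:

```python
import torch

t = torch.empty(4, 512, 16, 16)
a = torch.randn(t.shape, device=t.device)
b = torch.randn_like(t)
# Same shape, device, and dtype; only the sampled values differ.
assert a.shape == b.shape and a.device == b.device and a.dtype == b.dtype
```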
@@ -850,6 +849,7 @@ class StyleGan2DivergenceLoss(L.ConfigurableLoss):
         self.for_gen = opt['gen_loss']
         self.gp_frequency = opt['gradient_penalty_frequency']
         self.noise = opt['noise'] if 'noise' in opt.keys() else 0
+        self.logistic = opt_get(opt, ['logistic'], False)  # Applies a logistic curve to the output logits, which is what the StyleGAN2 authors used.
 
     def forward(self, net, state):
         real_input = state[self.real]
@@ -861,11 +861,19 @@ class StyleGan2DivergenceLoss(L.ConfigurableLoss):
         D = self.env['discriminators'][self.discriminator]
         fake = D(fake_input)
         if self.for_gen:
-            return fake.mean()
+            if self.logistic:
+                return F.softplus(-fake).mean()
+            else:
+                return fake.mean()
         else:
             real_input.requires_grad_()  # <-- Needed to compute gradients on the input.
             real = D(real_input)
-            divergence_loss = (F.relu(1 + real) + F.relu(1 - fake)).mean()
+            if self.logistic:
+                rl = F.softplus(-real).mean()
+                fl = F.softplus(fake).mean()
+                return fl + rl
+            else:
+                divergence_loss = (F.relu(1 + real) + F.relu(1 - fake)).mean()
 
             # Apply gradient penalty. TODO: migrate this elsewhere.
             if self.env['step'] % self.gp_frequency == 0:
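The new `logistic` branches implement the non-saturating logistic loss used by the StyleGAN2 authors, as an alternative to the hinge-style divergence retained in the `else` paths. Stripped of the trainer plumbing, the two sides reduce to the following (a standalone sketch; `d_real` and `d_fake` stand for discriminator logits):

```python
import torch.nn.functional as F

def g_logistic_loss(d_fake):
    # Generator wants D(fake) to be large; softplus(-x) = log(1 + exp(-x))
    # is the non-saturating form of -log(sigmoid(x)).
    return F.softplus(-d_fake).mean()

def d_logistic_loss(d_real, d_fake):
    # Discriminator pushes D(real) up and D(fake) down.
    return F.softplus(-d_real).mean() + F.softplus(d_fake).mean()
```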
@@ -11,6 +11,10 @@ from torch.autograd import Function
 
+# Ops -> The rosinality repo uses native cuda kernels for fused LeakyReLUs and upsamplers. This version extracts the
+# "cpu" alternative code and uses that instead for compatibility reasons.
+from trainer.networks import register_model
+from utils.util import opt_get
 
 
 class FusedLeakyReLU(nn.Module):
     def __init__(self, channel, bias=True, negative_slope=0.2, scale=2 ** 0.5):
         super().__init__()
@@ -609,11 +613,7 @@ class Generator(nn.Module):
 
         image = skip
 
-        if return_latents:
-            return image, latent
-
-        else:
-            return image, None
+        return image, latent
 
 
 class ConvLayer(nn.Sequential):
@@ -741,3 +741,14 @@ class Discriminator(nn.Module):
         out = self.final_linear(out)
 
         return out
+
+
+@register_model
+def register_stylegan2_rosinality_gen(opt_net, opt):
+    kw = opt_get(opt_net, ['kwargs'], {})
+    return Generator(**kw)
+
+@register_model
+def register_stylegan2_rosinality_disc(opt_net, opt):
+    kw = opt_get(opt_net, ['kwargs'], {})
+    return Discriminator(**kw)
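These registrations expose the rosinality `Generator` and `Discriminator` to the trainer's network registry, with constructor arguments passed through verbatim from a `kwargs` block. A hedged sketch of a network definition that would reach this code (the top-level key naming is assumed from the `register_stylegan2_rosinality_*` function names; the `kwargs` mirror rosinality's `Generator(size, style_dim, n_mlp, channel_multiplier)` signature):

```python
# Hypothetical network options; 'kwargs' is forwarded verbatim to Generator(**kw).
gen_opt = {
    'which_model': 'stylegan2_rosinality_gen',  # assumed registry key
    'kwargs': {
        'size': 256,              # output resolution
        'style_dim': 512,         # latent dimensionality
        'n_mlp': 8,               # mapping network depth
        'channel_multiplier': 2,  # config-f width
    },
}
```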
@@ -13,17 +13,17 @@ import torch
 def main():
     split_img = False
     opt = {}
-    opt['n_thread'] = 3
+    opt['n_thread'] = 5
     opt['compression_level'] = 95  # JPEG compression quality rating.
     # CV_IMWRITE_PNG_COMPRESSION from 0 to 9. A higher value means a smaller size and longer
     # compression time. If reading raw images during training, use 0 for faster IO speed.
 
     opt['dest'] = 'file'
-    opt['input_folder'] = ['F:\\4k6k\\datasets\\ns_images\\imagesets\\pn_coven\\working']
-    opt['save_folder'] = 'F:\\4k6k\\datasets\\ns_images\\imagesets\\pn_coven\\cropped'
-    opt['imgsize'] = 1024
-    opt['bottom_crop'] = .1
-    opt['keep_folder'] = True
+    opt['input_folder'] = ['F:\\4k6k\\datasets\\images\\lsun\\lsun\\cats']
+    opt['save_folder'] = 'F:\\4k6k\\datasets\\images\\lsun\\lsun\\cats\\cropped'
+    opt['imgsize'] = 256
+    opt['bottom_crop'] = 0
+    opt['keep_folder'] = False
 
     save_folder = opt['save_folder']
     if not osp.exists(save_folder):
@@ -83,7 +83,7 @@ class TiledDataset(data.Dataset):
         pts = os.path.split(pts[0])
         output_folder = osp.join(self.opt['save_folder'], pts[-1])
         os.makedirs(output_folder, exist_ok=True)
-        cv2.imwrite(osp.join(output_folder, basename), img, [cv2.IMWRITE_JPEG_QUALITY, self.opt['compression_level']])
+        cv2.imwrite(osp.join(output_folder, basename.replace('.webp', '.jpg')), img, [cv2.IMWRITE_JPEG_QUALITY, self.opt['compression_level']])
         return None
 
     def __len__(self):
@@ -1,292 +0,0 @@
-# Converts from Tensorflow Stylegan2 weights to weights used by this model.
-# Original source: https://raw.githubusercontent.com/rosinality/stylegan2-pytorch/master/convert_weight.py
-# Adapted to lucidrains' Stylegan implementation.
-#
-# Also doesn't require you to install Tensorflow 1.15 or clone the nVidia repo.
-
-# THIS DOES NOT CURRENTLY WORK.
-# It does transfer all weights from the stylegan model to the lucidrains one, but does not produce correct results.
-# The rosinality script this was stolen from has some "odd" intracacies that may be at cause for this: for example
-# weight "flipping" in the conv layers which I do not understand. It may also be because I botched some of the mods
-# required to make the lucidrains implementation conformant. I'll (maybe) get back to this some day.
-
-import argparse
-import os
-import sys
-import pickle
-import math
-
-import torch
-import numpy as np
-from torchvision import utils
-
-
-# Converts from the TF state_dict input provided into the vars originally expected from the rosinality converter.
-from models.stylegan.stylegan2_lucidrains import StyleGan2GeneratorWithLatent
-
-
-def get_vars(vars, source_name):
-    net_name = source_name.split('/')[0]
-    vars_as_tuple_list = vars[net_name]['variables']
-    result_vars = {}
-    for t in vars_as_tuple_list:
-        result_vars[t[0]] = t[1]
-    return result_vars, source_name.replace(net_name + "/", "")
-
-def get_vars_direct(vars, source_name):
-    v, n = get_vars(vars, source_name)
-    return v[n]
-
-
-def convert_modconv(vars, source_name, target_name, flip=False, numeral=1):
-    vars, source_name = get_vars(vars, source_name)
-    weight = vars[source_name + "/weight"]
-    mod_weight = vars[source_name + "/mod_weight"]
-    mod_bias = vars[source_name + "/mod_bias"]
-    noise = vars[source_name + "/noise_strength"]
-    bias = vars[source_name + "/bias"]
-
-    dic = {
-        f"conv{numeral}.weight": weight.transpose((3, 2, 0, 1)),
-        f"to_style{numeral}.weight": mod_weight.transpose((1, 0)),
-        f"to_style{numeral}.bias": mod_bias + 1,
-        f"noise{numeral}_scale": np.array([noise]),
-        f"activation{numeral}.bias": bias,
-    }
-
-    dic_torch = {}
-
-    for k, v in dic.items():
-        dic_torch[target_name + "." + k] = torch.from_numpy(v)
-
-    if flip:
-        dic_torch[target_name + f".conv{numeral}.weight"] = torch.flip(
-            dic_torch[target_name + f".conv{numeral}.weight"], [2, 3]
-        )
-
-    return dic_torch
-
-
-def convert_conv(vars, source_name, target_name, bias=True, start=0):
-    vars, source_name = get_vars(vars, source_name)
-    weight = vars[source_name + "/weight"]
-
-    dic = {"weight": weight.transpose((3, 2, 0, 1))}
-
-    if bias:
-        dic["bias"] = vars[source_name + "/bias"]
-
-    dic_torch = {}
-
-    dic_torch[target_name + f".{start}.weight"] = torch.from_numpy(dic["weight"])
-
-    if bias:
-        dic_torch[target_name + f".{start + 1}.bias"] = torch.from_numpy(dic["bias"])
-
-    return dic_torch
-
-
-def convert_torgb(vars, source_name, target_name):
-    vars, source_name = get_vars(vars, source_name)
-    weight = vars[source_name + "/weight"]
-    mod_weight = vars[source_name + "/mod_weight"]
-    mod_bias = vars[source_name + "/mod_bias"]
-    bias = vars[source_name + "/bias"]
-
-    dic = {
-        "conv.weight": weight.transpose((3, 2, 0, 1)),
-        "to_style.weight": mod_weight.transpose((1, 0)),
-        "to_style.bias": mod_bias + 1,
-        # "bias": bias.reshape((1, 3, 1, 1)), TODO: where is this?
-    }
-
-    dic_torch = {}
-
-    for k, v in dic.items():
-        dic_torch[target_name + "." + k] = torch.from_numpy(v)
-
-    return dic_torch
-
-
-def convert_dense(vars, source_name, target_name):
-    vars, source_name = get_vars(vars, source_name)
-    weight = vars[source_name + "/weight"]
-    bias = vars[source_name + "/bias"]
-
-    dic = {"weight": weight.transpose((1, 0)), "bias": bias}
-
-    dic_torch = {}
-
-    for k, v in dic.items():
-        dic_torch[target_name + "." + k] = torch.from_numpy(v)
-
-    return dic_torch
-
-
-def update(state_dict, new, strict=True):
-    for k, v in new.items():
-        if strict:
-            if k not in state_dict:
-                raise KeyError(k + " is not found")
-
-            if v.shape != state_dict[k].shape:
-                raise ValueError(f"Shape mismatch: {v.shape} vs {state_dict[k].shape}")
-
-        state_dict[k] = v
-
-
-def discriminator_fill_statedict(statedict, vars, size):
-    log_size = int(math.log(size, 2))
-
-    update(statedict, convert_conv(vars, f"{size}x{size}/FromRGB", "convs.0"))
-
-    conv_i = 1
-
-    for i in range(log_size - 2, 0, -1):
-        reso = 4 * 2 ** i
-        update(
-            statedict,
-            convert_conv(vars, f"{reso}x{reso}/Conv0", f"convs.{conv_i}.conv1"),
-        )
-        update(
-            statedict,
-            convert_conv(
-                vars, f"{reso}x{reso}/Conv1_down", f"convs.{conv_i}.conv2", start=1
-            ),
-        )
-        update(
-            statedict,
-            convert_conv(
-                vars, f"{reso}x{reso}/Skip", f"convs.{conv_i}.skip", start=1, bias=False
-            ),
-        )
-        conv_i += 1
-
-    update(statedict, convert_conv(vars, f"4x4/Conv", "final_conv"))
-    update(statedict, convert_dense(vars, f"4x4/Dense0", "final_linear.0"))
-    update(statedict, convert_dense(vars, f"Output", "final_linear.1"))
-
-    return statedict
-
-
-def fill_statedict(state_dict, vars, size):
-    log_size = int(math.log(size, 2))
-
-    for i in range(8):
-        update(state_dict, convert_dense(vars, f"G_mapping/Dense{i}", f"vectorizer.net.{i}"))
-
-    update(
-        state_dict,
-        {
-            "gen.initial_block": torch.from_numpy(
-                get_vars_direct(vars, "G_synthesis/4x4/Const/const")
-            )
-        },
-    )
-
-    for i in range(log_size - 1):
-        reso = 4 * 2 ** i
-        update(
-            state_dict,
-            convert_torgb(vars, f"G_synthesis/{reso}x{reso}/ToRGB", f"gen.blocks.{i}.to_rgb"),
-        )
-
-    update(state_dict, convert_modconv(vars, "G_synthesis/4x4/Conv", "gen.blocks.0", numeral=1))
-
-    for i in range(1, log_size - 1):
-        reso = 4 * 2 ** i
-        update(
-            state_dict,
-            convert_modconv(
-                vars,
-                f"G_synthesis/{reso}x{reso}/Conv0_up",
-                f"gen.blocks.{i}",
-                #flip=True,  # TODO: why??
-                numeral=1
-            ),
-        )
-        update(
-            state_dict,
-            convert_modconv(
-                vars, f"G_synthesis/{reso}x{reso}/Conv1", f"gen.blocks.{i}", numeral=2
-            ),
-        )
-
-    '''
-    TODO: consider porting this, though I dont think it is necessary.
-    for i in range(0, (log_size - 2) * 2 + 1):
-        update(
-            state_dict,
-            {
-                f"noises.noise_{i}": torch.from_numpy(
-                    get_vars_direct(vars, f"G_synthesis/noise{i}")
-                )
-            },
-        )
-    '''
-
-    return state_dict
-
-
-if __name__ == "__main__":
-    device = "cuda"
-
-    parser = argparse.ArgumentParser(
-        description="Tensorflow to pytorch model checkpoint converter"
-    )
-    parser.add_argument(
-        "--gen", action="store_true", help="convert the generator weights"
-    )
-    parser.add_argument(
-        "--channel_multiplier",
-        type=int,
-        default=2,
-        help="channel multiplier factor. config-f = 2, else = 1",
-    )
-    parser.add_argument("path", metavar="PATH", help="path to the tensorflow weights")
-
-    args = parser.parse_args()
-    sys.path.append('scripts\\stylegan2')
-
-    import dnnlib
-    from dnnlib.tflib.network import generator, gen_ema
-
-    with open(args.path, "rb") as f:
-        pickle.load(f)
-
-    # Weight names are ordered by size. The last name will be something like '1024x1024/<blah>'. We just need to grab that first number.
-    size = int(generator['G_synthesis']['variables'][-1][0].split('x')[0])
-
-    g = StyleGan2GeneratorWithLatent(image_size=size, latent_dim=512, style_depth=8)
-    state_dict = g.state_dict()
-    state_dict = fill_statedict(state_dict, gen_ema, size)
-
-    g.load_state_dict(state_dict, strict=True)
-
-    latent_avg = torch.from_numpy(get_vars_direct(gen_ema, "G/dlatent_avg"))
-
-    ckpt = {"g_ema": state_dict, "latent_avg": latent_avg}
-
-    if args.gen:
-        g_train = Generator(size, 512, 8, channel_multiplier=args.channel_multiplier)
-        g_train_state = g_train.state_dict()
-        g_train_state = fill_statedict(g_train_state, generator, size)
-        ckpt["g"] = g_train_state
-
-    name = os.path.splitext(os.path.basename(args.path))[0]
-    torch.save(ckpt, name + ".pt")
-
-    batch_size = {256: 16, 512: 9, 1024: 4}
-    n_sample = batch_size.get(size, 25)
-
-    g = g.to(device)
-
-    z = np.random.RandomState(5).randn(n_sample, 512).astype("float32")
-
-    with torch.no_grad():
-        img_pt, _ = g(8)
-
-    utils.save_image(
-        img_pt, name + ".png", nrow=n_sample, normalize=True, range=(-1, 1)
-    )
@@ -123,7 +123,7 @@ def update(state_dict, new):
 def discriminator_fill_statedict(statedict, vars, size):
     log_size = int(math.log(size, 2))
 
-    update(statedict, convert_conv(vars, f"{size}x{size}/FromRGB", "convs.0"))
+    update(statedict, convert_conv(vars, f"D/{size}x{size}/FromRGB", "convs.0"))
 
     conv_i = 1
 
@@ -131,25 +131,25 @@ def discriminator_fill_statedict(statedict, vars, size):
         reso = 4 * 2 ** i
         update(
             statedict,
-            convert_conv(vars, f"{reso}x{reso}/Conv0", f"convs.{conv_i}.conv1"),
+            convert_conv(vars, f"D/{reso}x{reso}/Conv0", f"convs.{conv_i}.conv1"),
         )
         update(
             statedict,
             convert_conv(
-                vars, f"{reso}x{reso}/Conv1_down", f"convs.{conv_i}.conv2", start=1
+                vars, f"D/{reso}x{reso}/Conv1_down", f"convs.{conv_i}.conv2", start=1
             ),
         )
         update(
             statedict,
             convert_conv(
-                vars, f"{reso}x{reso}/Skip", f"convs.{conv_i}.skip", start=1, bias=False
+                vars, f"D/{reso}x{reso}/Skip", f"convs.{conv_i}.skip", start=1, bias=False
             ),
         )
         conv_i += 1
 
-    update(statedict, convert_conv(vars, f"4x4/Conv", "final_conv"))
-    update(statedict, convert_dense(vars, f"4x4/Dense0", "final_linear.0"))
-    update(statedict, convert_dense(vars, f"Output", "final_linear.1"))
+    update(statedict, convert_conv(vars, f"D/4x4/Conv", "final_conv"))
+    update(statedict, convert_dense(vars, f"D/4x4/Dense0", "final_linear.0"))
+    update(statedict, convert_dense(vars, f"D/Output", "final_linear.1"))
 
     return statedict
@@ -249,9 +249,14 @@ if __name__ == "__main__":
     g = Generator(size, 512, 8, channel_multiplier=args.channel_multiplier)
     state_dict = g.state_dict()
     state_dict = fill_statedict(state_dict, gen_ema, size)
 
     g.load_state_dict(state_dict, strict=True)
 
+    d = Discriminator(size, args.channel_multiplier)
+    dstate_dict = d.state_dict()
+    dstate_dict = discriminator_fill_statedict(dstate_dict, discriminator, size)
+    d.load_state_dict(dstate_dict, strict=True)
+
     latent_avg = torch.from_numpy(get_vars_direct(gen_ema, "G/dlatent_avg"))
 
     ckpt = {"g_ema": state_dict, "latent_avg": latent_avg}
@@ -269,14 +274,16 @@ if __name__ == "__main__":
         ckpt["d"] = d_state
 
     name = os.path.splitext(os.path.basename(args.path))[0]
-    torch.save(state_dict, name + ".pth")
+    torch.save(state_dict, f"{name}_gen.pth")
+    torch.save(dstate_dict, f"{name}_disc.pth")
 
     batch_size = {256: 16, 512: 9, 1024: 4}
     n_sample = batch_size.get(size, 25)
 
     g = g.to(device)
+    d = d.to(device)
 
-    z = np.random.RandomState(5).randn(n_sample, 512).astype("float32")
+    z = np.random.RandomState(1).randn(n_sample, 512).astype("float32")
 
     with torch.no_grad():
         img_pt, _ = g(
@@ -285,6 +292,8 @@ if __name__ == "__main__":
         truncation_latent=latent_avg.to(device),
         randomize_noise=False,
     )
+    disc = d(img_pt)
+    print(disc)
 
     utils.save_image(
         img_pt, name + ".png", nrow=n_sample, normalize=True, range=(-1, 1)
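Net effect of the converter changes: the discriminator is now extracted alongside the generator (the TF variable names gain a `D/` scope prefix), the two state dicts are saved to separate `{name}_gen.pth` / `{name}_disc.pth` files, and the converted discriminator is smoke-tested by scoring the sample grid. A hedged sketch of loading the results back into the rosinality models (the import path is assumed; checkpoint names follow the pattern above and are illustrative):

```python
import torch
from models.stylegan.stylegan2_rosinality import Generator, Discriminator  # assumed path

g = Generator(256, 512, 8, channel_multiplier=2)
d = Discriminator(256, channel_multiplier=2)
# The files written above are raw state dicts, so load_state_dict applies directly.
g.load_state_dict(torch.load('stylegan2-cats-config-f_gen.pth'))
d.load_state_dict(torch.load('stylegan2-cats-config-f_disc.pth'))
```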
@@ -295,7 +295,7 @@ class Trainer:
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_vqvae3_stage1.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_cats_stylegan2_rosinality.yml')
    parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
@@ -13,6 +13,8 @@ from trainer.steps import ConfigurableStep
 from trainer.experiments.experiments import get_experiment_for_name
 import torchvision.utils as utils
 
+from utils.util import opt_get
+
 logger = logging.getLogger('base')
 
 
@@ -253,6 +255,8 @@ class ExtensibleTrainer(BaseModel):
 
         # Record visual outputs for usage in debugging and testing.
         if 'visuals' in self.opt['logger'].keys() and self.rank <= 0 and step % self.opt['logger']['visual_debug_rate'] == 0:
+            denorm = opt_get(self.opt, ['logger', 'denormalize'], False)
+            denorm_range = tuple(opt_get(self.opt, ['logger', 'denormalize_range'], None))
             sample_save_path = os.path.join(self.opt['path']['models'], "..", "visual_dbg")
             for v in self.opt['logger']['visuals']:
                 if v not in state.keys():
@@ -264,12 +268,12 @@ class ExtensibleTrainer(BaseModel):
                     if rdbgv.shape[1] > 3:
                         rdbgv = rdbgv[:, :3, :, :]
                     os.makedirs(os.path.join(sample_save_path, v), exist_ok=True)
-                    utils.save_image(rdbgv.float(), os.path.join(sample_save_path, v, "%05i_%02i_%02i.png" % (step, rvi, i)))
+                    utils.save_image(rdbgv.float(), os.path.join(sample_save_path, v, "%05i_%02i_%02i.png" % (step, rvi, i)), normalize=denorm, range=denorm_range)
                 else:
                     if dbgv.shape[1] > 3:
                         dbgv = dbgv[:, :3, :, :]
                     os.makedirs(os.path.join(sample_save_path, v), exist_ok=True)
-                    utils.save_image(dbgv.float(), os.path.join(sample_save_path, v, "%05i_%02i.png" % (step, i)))
+                    utils.save_image(dbgv.float(), os.path.join(sample_save_path, v, "%05i_%02i.png" % (step, i)), normalize=denorm, range=denorm_range)
             # Some models have their own specific visual debug routines.
             for net_name, net in self.networks.items():
                 if hasattr(net.module, "visual_dbg"):
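On the output path, de-normalization is driven by two new logger options; `utils.save_image(..., normalize=True, range=(-1, 1))` linearly remaps that range back to [0, 1] before the debug PNGs are written. A sketch of the logger block that would pair with `stylegan2_norm` on the input side (visual names are illustrative):

```python
# Hypothetical logger options. Note that 'denormalize_range' must be a
# two-element sequence, since the code above wraps it in tuple() and hands
# it to save_image's range argument.
logger_opt = {
    'visuals': ['gen', 'hq'],
    'visual_debug_rate': 500,
    'denormalize': True,
    'denormalize_range': [-1, 1],
}
```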
@@ -8,6 +8,9 @@ from pytorch_fid import fid_score
 
 
 # Evaluator that generates uniform noise to feed into a generator, then calculates a FID score on the results.
+from utils.util import opt_get
+
+
 class StyleTransferEvaluator(evaluator.Evaluator):
     def __init__(self, model, opt_eval, env):
         super().__init__(model, opt_eval, env)
@@ -16,13 +19,18 @@ class StyleTransferEvaluator(evaluator.Evaluator):
         self.im_sz = opt_eval['image_size']
         self.fid_real_samples = opt_eval['real_fid_path']
         self.gen_output_index = opt_eval['gen_index'] if 'gen_index' in opt_eval.keys() else 0
+        self.noise_type = opt_get(opt_eval, ['noise_type'], 'imgnoise')
+        self.latent_dim = opt_get(opt_eval, ['latent_dim'], 512)  # Not needed if using 'imgnoise' input.
 
     def perform_eval(self):
         fid_fake_path = osp.join(self.env['base_path'], "../../models", "fid", str(self.env["step"]))
         os.makedirs(fid_fake_path, exist_ok=True)
         counter = 0
         for i in range(self.batches_per_eval):
-            batch = torch.FloatTensor(self.batch_sz, 3, self.im_sz, self.im_sz).uniform_(0., 1.).to(self.env['device'])
+            if self.noise_type == 'imgnoise':
+                batch = torch.FloatTensor(self.batch_sz, 3, self.im_sz, self.im_sz).uniform_(0., 1.).to(self.env['device'])
+            elif self.noise_type == 'stylenoise':
+                batch = [torch.randn(self.batch_sz, self.latent_dim).to(self.env['device'])]
             gen = self.model(batch)
             if not isinstance(gen, list) and not isinstance(gen, tuple):
                 gen = [gen]
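With `noise_type` set to `'stylenoise'`, the evaluator now feeds the model a list containing a single `(batch, latent_dim)` tensor, matching the styles input of a rosinality-style generator, instead of a uniform noise image. A hedged sketch of the evaluator options (only keys read in the hunks above are shown; values are illustrative):

```python
# Hypothetical evaluator options for running FID against a rosinality-style GAN.
eval_opt = {
    'image_size': 256,
    'real_fid_path': '/data/lsun/cats/cropped',  # illustrative path
    'gen_index': 0,             # generator returns (image, latent); FID uses the image
    'noise_type': 'stylenoise',
    'latent_dim': 512,
}
```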
@@ -6,6 +6,7 @@ from torch.cuda.amp import autocast
 
 from trainer.inject import Injector
 from trainer.losses import extract_params_from_state
+from utils.util import opt_get
 from utils.weight_scheduler import get_scheduler_for_opt
 
 
@@ -386,3 +387,19 @@ class RandomCropInjector(Injector):
     def forward(self, state):
         return {self.output: self.operator(state[self.input])}
+
+
+class Stylegan2NoiseInjector(Injector):
+    def __init__(self, opt, env):
+        super().__init__(opt, env)
+        self.mix_prob = opt_get(opt, ['mix_probability'], .9)
+        self.latent_dim = opt_get(opt, ['latent_dim'], 512)
+
+    def make_noise(self, batch, latent_dim, n_noise, device):
+        return torch.randn(n_noise, batch, latent_dim, device=device).unbind(0)
+
+    def forward(self, state):
+        i = state[self.input]
+        if self.mix_prob > 0 and random.random() < self.mix_prob:
+            return {self.output: self.make_noise(i.shape[0], self.latent_dim, 2, i.device)}
+        else:
+            return {self.output: self.make_noise(i.shape[0], self.latent_dim, 1, i.device)}
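The injector returns a tuple of one or two `(batch, latent_dim)` noise tensors: two with probability `mix_probability`, which lets the generator perform style mixing by crossing over between the latents, and one otherwise (the "separated paths" from the commit message). A hedged sketch of wiring it into a step definition (the `type` key is assumed to follow the usual snake-case registry convention):

```python
# Hypothetical injector options feeding a rosinality generator.
injector_opt = {
    'type': 'stylegan2_noise',  # assumed registry name for Stylegan2NoiseInjector
    'in': 'hq',                 # used only to size the noise and pick the device
    'out': 'latents',
    'mix_probability': 0.9,
    'latent_dim': 512,
}
```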