stylegan2 in ml art school!

This commit is contained in:
James Betker 2020-11-12 15:42:05 -07:00
parent fd97573085
commit 2d3449d7a5
7 changed files with 834 additions and 3 deletions

View File

@ -41,6 +41,8 @@ def create_dataset(dataset_opt):
from data.multiscale_dataset import MultiScaleDataset as D
elif mode == 'paired_frame':
from data.paired_frame_dataset import PairedFrameDataset as D
elif mode == 'stylegan2':
from data.stylegan2_dataset import Stylegan2Dataset as D
else:
raise NotImplementedError('Dataset [{:s}] is not recognized.'.format(mode))
dataset = D(dataset_opt)

View File

@ -0,0 +1,101 @@
from functools import partial
from random import random
import torch
import torchvision
from PIL import Image
from torch.utils import data
from torchvision import transforms
import torch.nn as nn
from pathlib import Path
from models.archs.stylegan2 import exists
def convert_transparent_to_rgb(image):
if image.mode != 'RGB':
return image.convert('RGB')
return image
def convert_rgb_to_transparent(image):
if image.mode != 'RGBA':
return image.convert('RGBA')
return image
def resize_to_minimum_size(min_size, image):
if max(*image.size) < min_size:
return torchvision.transforms.functional.resize(image, min_size)
return image
class RandomApply(nn.Module):
def __init__(self, prob, fn, fn_else=lambda x: x):
super().__init__()
self.fn = fn
self.fn_else = fn_else
self.prob = prob
def forward(self, x):
fn = self.fn if random() < self.prob else self.fn_else
return fn(x)
class expand_greyscale(object):
def __init__(self, transparent):
self.transparent = transparent
def __call__(self, tensor):
channels = tensor.shape[0]
num_target_channels = 4 if self.transparent else 3
if channels == num_target_channels:
return tensor
alpha = None
if channels == 1:
color = tensor.expand(3, -1, -1)
elif channels == 2:
color = tensor[:1].expand(3, -1, -1)
alpha = tensor[1:]
else:
raise Exception(f'image with invalid number of channels given {channels}')
if not exists(alpha) and self.transparent:
alpha = torch.ones(1, *tensor.shape[1:], device=tensor.device)
return color if not self.transparent else torch.cat((color, alpha))
class Stylegan2Dataset(data.Dataset):
def __init__(self, opt):
super().__init__()
EXTS = ['jpg', 'jpeg', 'png']
self.folder = opt['path']
self.image_size = opt['target_size']
self.paths = [p for ext in EXTS for p in Path(f'{self.folder}').glob(f'**/*.{ext}')]
aug_prob = opt['aug_prob']
transparent = opt['transparent'] if 'transparent' in opt.keys() else False
assert len(self.paths) > 0, f'No images were found in {self.folder} for training'
convert_image_fn = convert_transparent_to_rgb if not transparent else convert_rgb_to_transparent
num_channels = 3 if not transparent else 4
self.transform = transforms.Compose([
transforms.Lambda(convert_image_fn),
transforms.Lambda(partial(resize_to_minimum_size, self.image_size)),
transforms.Resize(self.image_size),
RandomApply(aug_prob, transforms.RandomResizedCrop(self.image_size, scale=(0.5, 1.0), ratio=(0.98, 1.02)),
transforms.CenterCrop(self.image_size)),
transforms.ToTensor(),
transforms.Lambda(expand_greyscale(transparent))
])
def __len__(self):
return len(self.paths)
def __getitem__(self, index):
path = self.paths[index]
img = Image.open(path)
img = self.transform(img)
return {'LQ': img, 'GT': img, 'GT_path': str(path)}

View File

@ -0,0 +1,651 @@
import math
import multiprocessing
from contextlib import contextmanager, ExitStack
from functools import partial
from math import log2
from random import random
import torch
import torch.nn.functional as F
from kornia.filters import filter2D
from linear_attention_transformer import ImageLinearAttention
from torch import nn
from torch.autograd import grad as torch_grad
from vector_quantize_pytorch import VectorQuantize
from utils.util import checkpoint
try:
from apex import amp
APEX_AVAILABLE = True
except:
APEX_AVAILABLE = False
assert torch.cuda.is_available(), 'You need to have an Nvidia GPU with CUDA installed.'
num_cores = multiprocessing.cpu_count()
# constants
EPS = 1e-8
CALC_FID_NUM_IMAGES = 12800
# helper classes
def DiffAugment(x, types=[]):
for p in types:
for f in AUGMENT_FNS[p]:
x = f(x)
return x.contiguous()
def rand_brightness(x):
x = x + (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5)
return x
def rand_saturation(x):
x_mean = x.mean(dim=1, keepdim=True)
x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2) + x_mean
return x
def rand_contrast(x):
x_mean = x.mean(dim=[1, 2, 3], keepdim=True)
x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5) + x_mean
return x
def rand_translation(x, ratio=0.125):
shift_x, shift_y = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
translation_x = torch.randint(-shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device)
translation_y = torch.randint(-shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device)
grid_batch, grid_x, grid_y = torch.meshgrid(
torch.arange(x.size(0), dtype=torch.long, device=x.device),
torch.arange(x.size(2), dtype=torch.long, device=x.device),
torch.arange(x.size(3), dtype=torch.long, device=x.device),
)
grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1)
grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1)
x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0])
x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x, grid_y].permute(0, 3, 1, 2)
return x
def rand_cutout(x, ratio=0.5):
cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
offset_x = torch.randint(0, x.size(2) + (1 - cutout_size[0] % 2), size=[x.size(0), 1, 1], device=x.device)
offset_y = torch.randint(0, x.size(3) + (1 - cutout_size[1] % 2), size=[x.size(0), 1, 1], device=x.device)
grid_batch, grid_x, grid_y = torch.meshgrid(
torch.arange(x.size(0), dtype=torch.long, device=x.device),
torch.arange(cutout_size[0], dtype=torch.long, device=x.device),
torch.arange(cutout_size[1], dtype=torch.long, device=x.device),
)
grid_x = torch.clamp(grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1)
grid_y = torch.clamp(grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1)
mask = torch.ones(x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device)
mask[grid_batch, grid_x, grid_y] = 0
x = x * mask.unsqueeze(1)
return x
AUGMENT_FNS = {
'color': [rand_brightness, rand_saturation, rand_contrast],
'translation': [rand_translation],
'cutout': [rand_cutout],
}
class NanException(Exception):
pass
class EMA():
def __init__(self, beta):
super().__init__()
self.beta = beta
def update_average(self, old, new):
if not exists(old):
return new
return old * self.beta + (1 - self.beta) * new
class Flatten(nn.Module):
def forward(self, x):
return x.reshape(x.shape[0], -1)
class Residual(nn.Module):
def __init__(self, fn):
super().__init__()
self.fn = fn
def forward(self, x):
return self.fn(x) + x
class Rezero(nn.Module):
def __init__(self, fn):
super().__init__()
self.fn = fn
self.g = nn.Parameter(torch.zeros(1))
def forward(self, x):
return self.fn(x) * self.g
class PermuteToFrom(nn.Module):
def __init__(self, fn):
super().__init__()
self.fn = fn
def forward(self, x):
x = x.permute(0, 2, 3, 1)
out, loss = self.fn(x)
out = out.permute(0, 3, 1, 2)
return out, loss
class Blur(nn.Module):
def __init__(self):
super().__init__()
f = torch.Tensor([1, 2, 1])
self.register_buffer('f', f)
def forward(self, x):
f = self.f
f = f[None, None, :] * f[None, :, None]
return filter2D(x, f, normalized=True)
# one layer of self-attention and feedforward, for images
attn_and_ff = lambda chan: nn.Sequential(*[
Residual(Rezero(ImageLinearAttention(chan, norm_queries=True))),
Residual(Rezero(nn.Sequential(nn.Conv2d(chan, chan * 2, 1), leaky_relu(), nn.Conv2d(chan * 2, chan, 1))))
])
# helpers
def exists(val):
return val is not None
@contextmanager
def null_context():
yield
def combine_contexts(contexts):
@contextmanager
def multi_contexts():
with ExitStack() as stack:
yield [stack.enter_context(ctx()) for ctx in contexts]
return multi_contexts
def default(value, d):
return value if exists(value) else d
def cycle(iterable):
while True:
for i in iterable:
yield i
def cast_list(el):
return el if isinstance(el, list) else [el]
def is_empty(t):
if isinstance(t, torch.Tensor):
return t.nelement() == 0
return not exists(t)
def raise_if_nan(t):
if torch.isnan(t):
raise NanException
def gradient_accumulate_contexts(gradient_accumulate_every, is_ddp, ddps):
if is_ddp:
num_no_syncs = gradient_accumulate_every - 1
head = [combine_contexts(map(lambda ddp: ddp.no_sync, ddps))] * num_no_syncs
tail = [null_context]
contexts = head + tail
else:
contexts = [null_context] * gradient_accumulate_every
for context in contexts:
with context():
yield
def loss_backwards(fp16, loss, optimizer, loss_id, **kwargs):
if fp16:
with amp.scale_loss(loss, optimizer, loss_id) as scaled_loss:
scaled_loss.backward(**kwargs)
else:
loss.backward(**kwargs)
def gradient_penalty(images, output, weight=10):
batch_size = images.shape[0]
gradients = torch_grad(outputs=output, inputs=images,
grad_outputs=torch.ones(output.size(), device=images.device),
create_graph=True, retain_graph=True, only_inputs=True)[0]
gradients = gradients.reshape(batch_size, -1)
return weight * ((gradients.norm(2, dim=1) - 1) ** 2).mean()
def calc_pl_lengths(styles, images):
device = images.device
num_pixels = images.shape[2] * images.shape[3]
pl_noise = torch.randn(images.shape, device=device) / math.sqrt(num_pixels)
outputs = (images * pl_noise).sum()
pl_grads = torch_grad(outputs=outputs, inputs=styles,
grad_outputs=torch.ones(outputs.shape, device=device),
create_graph=True, retain_graph=True, only_inputs=True)[0]
return (pl_grads ** 2).sum(dim=2).mean(dim=1).sqrt()
def image_noise(n, im_size, device):
return torch.FloatTensor(n, im_size, im_size, 1).uniform_(0., 1.).cuda(device)
def leaky_relu(p=0.2):
return nn.LeakyReLU(p, inplace=True)
def evaluate_in_chunks(max_batch_size, model, *args):
split_args = list(zip(*list(map(lambda x: x.split(max_batch_size, dim=0), args))))
chunked_outputs = [model(*i) for i in split_args]
if len(chunked_outputs) == 1:
return chunked_outputs[0]
return torch.cat(chunked_outputs, dim=0)
def set_requires_grad(model, bool):
for p in model.parameters():
p.requires_grad = bool
def slerp(val, low, high):
low_norm = low / torch.norm(low, dim=1, keepdim=True)
high_norm = high / torch.norm(high, dim=1, keepdim=True)
omega = torch.acos((low_norm * high_norm).sum(1))
so = torch.sin(omega)
res = (torch.sin((1.0 - val) * omega) / so).unsqueeze(1) * low + (torch.sin(val * omega) / so).unsqueeze(1) * high
return res
# augmentations
def random_hflip(tensor, prob):
if prob > random():
return tensor
return torch.flip(tensor, dims=(3,))
class StyleGan2Augmentor(nn.Module):
def __init__(self, D, image_size, types, prob):
super().__init__()
self.D = D
self.prob = prob
self.types = types
def forward(self, images, detach=False):
if random() < self.prob:
images = random_hflip(images, prob=0.5)
images = DiffAugment(images, types=self.types)
if detach:
images = images.detach()
return self.D(images)
# stylegan2 classes
class EqualLinear(nn.Module):
def __init__(self, in_dim, out_dim, lr_mul=1, bias=True):
super().__init__()
self.weight = nn.Parameter(torch.randn(out_dim, in_dim))
if bias:
self.bias = nn.Parameter(torch.zeros(out_dim))
self.lr_mul = lr_mul
def forward(self, input):
return F.linear(input, self.weight * self.lr_mul, bias=self.bias * self.lr_mul)
class StyleVectorizer(nn.Module):
def __init__(self, emb, depth, lr_mul=0.1):
super().__init__()
layers = []
for i in range(depth):
layers.extend([EqualLinear(emb, emb, lr_mul), leaky_relu()])
self.net = nn.Sequential(*layers)
def forward(self, x):
x = F.normalize(x, dim=1)
return self.net(x)
class RGBBlock(nn.Module):
def __init__(self, latent_dim, input_channel, upsample, rgba=False):
super().__init__()
self.input_channel = input_channel
self.to_style = nn.Linear(latent_dim, input_channel)
out_filters = 3 if not rgba else 4
self.conv = Conv2DMod(input_channel, out_filters, 1, demod=False)
self.upsample = nn.Sequential(
nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
Blur()
) if upsample else None
def forward(self, x, prev_rgb, istyle):
b, c, h, w = x.shape
style = self.to_style(istyle)
x = self.conv(x, style)
if exists(prev_rgb):
x = x + prev_rgb
if exists(self.upsample):
x = self.upsample(x)
return x
class Conv2DMod(nn.Module):
def __init__(self, in_chan, out_chan, kernel, demod=True, stride=1, dilation=1, **kwargs):
super().__init__()
self.filters = out_chan
self.demod = demod
self.kernel = kernel
self.stride = stride
self.dilation = dilation
self.weight = nn.Parameter(torch.randn((out_chan, in_chan, kernel, kernel)))
nn.init.kaiming_normal_(self.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')
def _get_same_padding(self, size, kernel, dilation, stride):
return ((size - 1) * (stride - 1) + dilation * (kernel - 1)) // 2
def forward(self, x, y):
b, c, h, w = x.shape
w1 = y[:, None, :, None, None]
w2 = self.weight[None, :, :, :, :]
weights = w2 * (w1 + 1)
if self.demod:
d = torch.rsqrt((weights ** 2).sum(dim=(2, 3, 4), keepdim=True) + EPS)
weights = weights * d
x = x.reshape(1, -1, h, w)
_, _, *ws = weights.shape
weights = weights.reshape(b * self.filters, *ws)
padding = self._get_same_padding(h, self.kernel, self.dilation, self.stride)
x = F.conv2d(x, weights, padding=padding, groups=b)
x = x.reshape(-1, self.filters, h, w)
return x
class GeneratorBlock(nn.Module):
def __init__(self, latent_dim, input_channels, filters, upsample=True, upsample_rgb=True, rgba=False):
super().__init__()
self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False) if upsample else None
self.to_style1 = nn.Linear(latent_dim, input_channels)
self.to_noise1 = nn.Linear(1, filters)
self.conv1 = Conv2DMod(input_channels, filters, 3)
self.to_style2 = nn.Linear(latent_dim, filters)
self.to_noise2 = nn.Linear(1, filters)
self.conv2 = Conv2DMod(filters, filters, 3)
self.activation = leaky_relu()
self.to_rgb = RGBBlock(latent_dim, filters, upsample_rgb, rgba)
def forward(self, x, prev_rgb, istyle, inoise):
if exists(self.upsample):
x = self.upsample(x)
inoise = inoise[:, :x.shape[2], :x.shape[3], :]
noise1 = self.to_noise1(inoise).permute((0, 3, 2, 1))
noise2 = self.to_noise2(inoise).permute((0, 3, 2, 1))
style1 = self.to_style1(istyle)
x = self.conv1(x, style1)
x = self.activation(x + noise1)
style2 = self.to_style2(istyle)
x = self.conv2(x, style2)
x = self.activation(x + noise2)
rgb = self.to_rgb(x, prev_rgb, istyle)
return x, rgb
class DiscriminatorBlock(nn.Module):
def __init__(self, input_channels, filters, downsample=True):
super().__init__()
self.conv_res = nn.Conv2d(input_channels, filters, 1, stride=(2 if downsample else 1))
self.net = nn.Sequential(
nn.Conv2d(input_channels, filters, 3, padding=1),
leaky_relu(),
nn.Conv2d(filters, filters, 3, padding=1),
leaky_relu()
)
self.downsample = nn.Sequential(
Blur(),
nn.Conv2d(filters, filters, 3, padding=1, stride=2)
) if downsample else None
def forward(self, x):
res = self.conv_res(x)
x = self.net(x)
if exists(self.downsample):
x = self.downsample(x)
x = (x + res) * (1 / math.sqrt(2))
return x
class Generator(nn.Module):
def __init__(self, image_size, latent_dim, network_capacity=16, transparent=False, attn_layers=[], no_const=False,
fmap_max=512):
super().__init__()
self.image_size = image_size
self.latent_dim = latent_dim
self.num_layers = int(log2(image_size) - 1)
filters = [network_capacity * (2 ** (i + 1)) for i in range(self.num_layers)][::-1]
set_fmap_max = partial(min, fmap_max)
filters = list(map(set_fmap_max, filters))
init_channels = filters[0]
filters = [init_channels, *filters]
in_out_pairs = zip(filters[:-1], filters[1:])
self.no_const = no_const
if no_const:
self.to_initial_block = nn.ConvTranspose2d(latent_dim, init_channels, 4, 1, 0, bias=False)
else:
self.initial_block = nn.Parameter(torch.randn((1, init_channels, 4, 4)))
self.initial_conv = nn.Conv2d(filters[0], filters[0], 3, padding=1)
self.blocks = nn.ModuleList([])
self.attns = nn.ModuleList([])
for ind, (in_chan, out_chan) in enumerate(in_out_pairs):
not_first = ind != 0
not_last = ind != (self.num_layers - 1)
num_layer = self.num_layers - ind
attn_fn = attn_and_ff(in_chan) if num_layer in attn_layers else None
self.attns.append(attn_fn)
block = GeneratorBlock(
latent_dim,
in_chan,
out_chan,
upsample=not_first,
upsample_rgb=not_last,
rgba=transparent
)
self.blocks.append(block)
def forward(self, styles, input_noise):
batch_size = styles.shape[0]
image_size = self.image_size
if self.no_const:
avg_style = styles.mean(dim=1)[:, :, None, None]
x = self.to_initial_block(avg_style)
else:
x = self.initial_block.expand(batch_size, -1, -1, -1)
rgb = None
styles = styles.transpose(0, 1)
x = self.initial_conv(x)
for style, block, attn in zip(styles, self.blocks, self.attns):
if exists(attn):
x = attn(x)
x, rgb = checkpoint(block, x, rgb, style, input_noise)
return rgb
# Wrapper that combines style vectorizer with the actual generator.
class StyleGan2GeneratorWithLatent(nn.Module):
def __init__(self, image_size, latent_dim=512, style_depth=8, lr_mlp=.1, network_capacity=16, transparent=False, attn_layers=[], no_const=False, fmap_max=512):
super().__init__()
self.vectorizer = StyleVectorizer(latent_dim, style_depth, lr_mul=lr_mlp)
self.gen = Generator(image_size, latent_dim, network_capacity, transparent, attn_layers, no_const, fmap_max)
self.mixed_prob = .9
self._init_weights()
def noise(self, n, latent_dim, device):
return torch.randn(n, latent_dim).cuda(device)
# To use per the stylegan paper, input should be uniform noise. This gen takes it in as a normal "image" format:
# b,f,h,w.
def forward(self, x):
b, f, h, w = x.shape
style = self.noise(b*2, self.gen.latent_dim, x.device)
w = self.vectorizer(style)
# Randomly distribute styles across layers
w_styles = w[:,None,:].expand(-1, self.gen.num_layers, -1).clone()
for j in range(b):
cutoff = int(torch.rand(()).numpy() * self.gen.num_layers)
if cutoff == self.gen.num_layers or random() > self.mixed_prob:
w_styles[j] = w_styles[j*2]
else:
w_styles[j, :cutoff] = w_styles[j*2, :cutoff]
w_styles[j, cutoff:] = w_styles[j*2+1, cutoff:]
w_styles = w_styles[:b]
# The underlying model expects the noise as b,h,w,1. Make it so.
return self.gen(w_styles, x[:,0,:,:].unsqueeze(dim=3)), w_styles
def _init_weights(self):
for m in self.modules():
if type(m) in {nn.Conv2d, nn.Linear}:
nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')
for block in self.gen.blocks:
nn.init.zeros_(block.to_noise1.weight)
nn.init.zeros_(block.to_noise2.weight)
nn.init.zeros_(block.to_noise1.bias)
nn.init.zeros_(block.to_noise2.bias)
class StyleGan2Discriminator(nn.Module):
def __init__(self, image_size, network_capacity=16, fq_layers=[], fq_dict_size=256, attn_layers=[],
transparent=False, fmap_max=512):
super().__init__()
num_layers = int(log2(image_size) - 1)
num_init_filters = 3 if not transparent else 4
blocks = []
filters = [num_init_filters] + [(64) * (2 ** i) for i in range(num_layers + 1)]
set_fmap_max = partial(min, fmap_max)
filters = list(map(set_fmap_max, filters))
chan_in_out = list(zip(filters[:-1], filters[1:]))
blocks = []
attn_blocks = []
quantize_blocks = []
for ind, (in_chan, out_chan) in enumerate(chan_in_out):
num_layer = ind + 1
is_not_last = ind != (len(chan_in_out) - 1)
block = DiscriminatorBlock(in_chan, out_chan, downsample=is_not_last)
blocks.append(block)
attn_fn = attn_and_ff(out_chan) if num_layer in attn_layers else None
attn_blocks.append(attn_fn)
quantize_fn = PermuteToFrom(VectorQuantize(out_chan, fq_dict_size)) if num_layer in fq_layers else None
quantize_blocks.append(quantize_fn)
self.blocks = nn.ModuleList(blocks)
self.attn_blocks = nn.ModuleList(attn_blocks)
self.quantize_blocks = nn.ModuleList(quantize_blocks)
chan_last = filters[-1]
latent_dim = 2 * 2 * chan_last
self.final_conv = nn.Conv2d(chan_last, chan_last, 3, padding=1)
self.flatten = Flatten()
self.to_logit = nn.Linear(latent_dim, 1)
self._init_weights()
def forward(self, x):
b, *_ = x.shape
quantize_loss = torch.zeros(1).to(x)
for (block, attn_block, q_block) in zip(self.blocks, self.attn_blocks, self.quantize_blocks):
x = checkpoint(block, x)
if exists(attn_block):
x = attn_block(x)
if exists(q_block):
x, _, loss = q_block(x)
quantize_loss += loss
x = self.final_conv(x)
x = self.flatten(x)
x = self.to_logit(x)
if exists(q_block):
return x.squeeze(), quantize_loss
else:
return x.squeeze()
def _init_weights(self):
for m in self.modules():
if type(m) in {nn.Conv2d, nn.Linear}:
nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')

View File

@ -22,6 +22,7 @@ from models.archs.stylegan.Discriminator_StyleGAN import StyleGanDiscriminator
from models.archs.pyramid_arch import BasicResamplingFlowNet
from models.archs.rrdb_with_adain_latent import AdaRRDBNet, LinearLatentEstimator
from models.archs.rrdb_with_latent import LatentEstimator, RRDBNetWithLatent, LatentEstimator2
from models.archs.stylegan2 import StyleGan2GeneratorWithLatent, StyleGan2Discriminator, StyleGan2Augmentor
from models.archs.teco_resgen import TecoGen
logger = logging.getLogger('base')
@ -131,6 +132,9 @@ def define_G(opt, net_key='network_G', scale=None):
netG = LatentEstimator(in_nc=3, nf=opt_net['nf'], overwrite_levels=overwrite)
elif which_model == "linear_latent_estimator":
netG = LinearLatentEstimator(in_nc=3, nf=opt_net['nf'])
elif which_model == 'stylegan2':
netG = StyleGan2GeneratorWithLatent(image_size=opt_net['image_size'], latent_dim=opt_net['latent_dim'],
style_depth=opt_net['style_depth'])
else:
raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model))
return netG
@ -189,6 +193,9 @@ def define_D_net(opt_net, img_sz=None, wrap=False):
netD = SRGAN_arch.PsnrApproximator(nf=opt_net['nf'], input_img_factor=img_sz / 128)
elif which_model == "pyramid_disc":
netD = SRGAN_arch.PyramidDiscriminator(in_nc=3, nf=opt_net['nf'])
elif which_model == "stylegan2_discriminator":
disc = StyleGan2Discriminator(image_size=opt_net['image_size'])
netD = StyleGan2Augmentor(disc, opt_net['image_size'], types=opt_net['augmentation_types'], prob=opt_net['augmentation_probability'])
else:
raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model))
return netD

View File

@ -147,6 +147,7 @@ class ScheduledScalarInjector(Injector):
class AddNoiseInjector(Injector):
def __init__(self, opt, env):
super(AddNoiseInjector, self).__init__(opt, env)
self.mode = opt['mode'] if 'mode' in opt.keys() else 'normal'
def forward(self, state):
# Scale can be a fixed float, or a state key (e.g. from ScheduledScalarInjector).
@ -155,7 +156,11 @@ class AddNoiseInjector(Injector):
else:
scale = self.opt['scale']
noise = torch.randn_like(state[self.opt['in']], device=self.env['device']) * scale
ref = state[self.opt['in']]
if self.mode == 'normal':
noise = torch.randn_like(ref) * scale
elif self.mode == 'uniform':
noise = torch.FloatTensor(ref.shape).uniform_(0.0, scale).to(ref.device)
return {self.opt['out']: state[self.opt['in']] + noise}

View File

@ -6,7 +6,8 @@ from models.networks import define_F
from models.loss import GANLoss
import random
import functools
import torchvision
import torch.nn.functional as F
import numpy as np
def create_loss(opt_loss, env):
@ -36,6 +37,10 @@ def create_loss(opt_loss, env):
return RecurrentLoss(opt_loss, env)
elif type == 'for_element':
return ForElementLoss(opt_loss, env)
elif type == 'stylegan2_divergence':
return StyleGan2DivergenceLoss(opt_loss, env)
elif type == 'stylegan2_pathlen':
return StyleGan2PathLengthLoss(opt_loss, env)
else:
raise NotImplementedError
@ -482,3 +487,62 @@ class ForElementLoss(ConfigurableLoss):
def clear_metrics(self):
self.loss.clear_metrics()
class StyleGan2DivergenceLoss(ConfigurableLoss):
def __init__(self, opt, env):
super().__init__(opt, env)
self.real = opt['real']
self.fake = opt['fake']
self.discriminator = opt['discriminator']
self.for_gen = opt['gen_loss']
self.gp_frequency = opt['gradient_penalty_frequency']
def forward(self, net, state):
D = self.env['discriminators'][self.discriminator]
fake = D(state[self.fake])
if self.for_gen:
return fake.mean()
else:
real_input = state[self.real].requires_grad_() # <-- Needed to compute gradients on the input.
real = D(real_input)
divergence_loss = (F.relu(1 + real) + F.relu(1 - fake)).mean()
gp = 0
if self.env['step'] % self.gp_frequency == 0:
# Apply gradient penalty. TODO: migrate this elsewhere.
from models.archs.stylegan2 import gradient_penalty
gp = gradient_penalty(real_input, real)
self.last_gp_loss = gp.clone().detach().item()
self.metrics.append(("gradient_penalty", gp))
real_input.requires_grad_(requires_grad=False)
return divergence_loss + gp
class StyleGan2PathLengthLoss(ConfigurableLoss):
def __init__(self, opt, env):
super().__init__(opt, env)
self.w_styles = opt['w_styles']
self.gen = opt['gen']
self.pl_mean = None
from models.archs.stylegan2 import EMA
self.pl_length_ma = EMA(.99)
def forward(self, net, state):
w_styles = state[self.w_styles]
gen = state[self.gen]
from models.archs.stylegan2 import calc_pl_lengths
pl_lengths = calc_pl_lengths(w_styles, gen)
avg_pl_length = np.mean(pl_lengths.detach().cpu().numpy())
from models.archs.stylegan2 import is_empty
if not is_empty(self.pl_mean):
pl_loss = ((pl_lengths - self.pl_mean) ** 2).mean()
if not torch.isnan(pl_loss):
return pl_loss
else:
print("Path length loss returned NaN!")
self.pl_mean = self.pl_length_ma.update_average(self.pl_mean, avg_pl_length)
return 0

View File

@ -152,7 +152,8 @@ class ConfigurableStep(Module):
# Some losses only activate after a set number of steps. For example, proto-discriminator losses can
# be very disruptive to a generator.
if 'after' in loss.opt.keys() and loss.opt['after'] > self.env['step'] or \
'before' in loss.opt.keys() and self.env['step'] > loss.opt['before']:
'before' in loss.opt.keys() and self.env['step'] > loss.opt['before'] or \
'every' in loss.opt.keys() and self.env['step'] % loss.opt['every'] != 0:
continue
l = loss(self.training_net, local_state)
total_loss += l * self.weights[loss_name]