From 2d3449d7a58881b0604dcf750a8709ae1a974d4c Mon Sep 17 00:00:00 2001 From: James Betker Date: Thu, 12 Nov 2020 15:42:05 -0700 Subject: [PATCH] stylegan2 in ml art school! --- codes/data/__init__.py | 2 + codes/data/stylegan2_dataset.py | 101 +++++ codes/models/archs/stylegan2.py | 651 ++++++++++++++++++++++++++++++++ codes/models/networks.py | 7 + codes/models/steps/injectors.py | 7 +- codes/models/steps/losses.py | 66 +++- codes/models/steps/steps.py | 3 +- 7 files changed, 834 insertions(+), 3 deletions(-) create mode 100644 codes/data/stylegan2_dataset.py create mode 100644 codes/models/archs/stylegan2.py diff --git a/codes/data/__init__.py b/codes/data/__init__.py index 7778c7cb..1e7b65b9 100644 --- a/codes/data/__init__.py +++ b/codes/data/__init__.py @@ -41,6 +41,8 @@ def create_dataset(dataset_opt): from data.multiscale_dataset import MultiScaleDataset as D elif mode == 'paired_frame': from data.paired_frame_dataset import PairedFrameDataset as D + elif mode == 'stylegan2': + from data.stylegan2_dataset import Stylegan2Dataset as D else: raise NotImplementedError('Dataset [{:s}] is not recognized.'.format(mode)) dataset = D(dataset_opt) diff --git a/codes/data/stylegan2_dataset.py b/codes/data/stylegan2_dataset.py new file mode 100644 index 00000000..d2e39331 --- /dev/null +++ b/codes/data/stylegan2_dataset.py @@ -0,0 +1,101 @@ +from functools import partial +from random import random + +import torch +import torchvision +from PIL import Image +from torch.utils import data +from torchvision import transforms +import torch.nn as nn +from pathlib import Path + +from models.archs.stylegan2 import exists + + +def convert_transparent_to_rgb(image): + if image.mode != 'RGB': + return image.convert('RGB') + return image + + +def convert_rgb_to_transparent(image): + if image.mode != 'RGBA': + return image.convert('RGBA') + return image + + +def resize_to_minimum_size(min_size, image): + if max(*image.size) < min_size: + return torchvision.transforms.functional.resize(image, min_size) + return image + +class RandomApply(nn.Module): + def __init__(self, prob, fn, fn_else=lambda x: x): + super().__init__() + self.fn = fn + self.fn_else = fn_else + self.prob = prob + + def forward(self, x): + fn = self.fn if random() < self.prob else self.fn_else + return fn(x) + + +class expand_greyscale(object): + def __init__(self, transparent): + self.transparent = transparent + + def __call__(self, tensor): + channels = tensor.shape[0] + num_target_channels = 4 if self.transparent else 3 + + if channels == num_target_channels: + return tensor + + alpha = None + if channels == 1: + color = tensor.expand(3, -1, -1) + elif channels == 2: + color = tensor[:1].expand(3, -1, -1) + alpha = tensor[1:] + else: + raise Exception(f'image with invalid number of channels given {channels}') + + if not exists(alpha) and self.transparent: + alpha = torch.ones(1, *tensor.shape[1:], device=tensor.device) + + return color if not self.transparent else torch.cat((color, alpha)) + + +class Stylegan2Dataset(data.Dataset): + def __init__(self, opt): + super().__init__() + EXTS = ['jpg', 'jpeg', 'png'] + self.folder = opt['path'] + self.image_size = opt['target_size'] + self.paths = [p for ext in EXTS for p in Path(f'{self.folder}').glob(f'**/*.{ext}')] + aug_prob = opt['aug_prob'] + transparent = opt['transparent'] if 'transparent' in opt.keys() else False + assert len(self.paths) > 0, f'No images were found in {self.folder} for training' + + convert_image_fn = convert_transparent_to_rgb if not transparent else convert_rgb_to_transparent + num_channels = 3 if not transparent else 4 + + self.transform = transforms.Compose([ + transforms.Lambda(convert_image_fn), + transforms.Lambda(partial(resize_to_minimum_size, self.image_size)), + transforms.Resize(self.image_size), + RandomApply(aug_prob, transforms.RandomResizedCrop(self.image_size, scale=(0.5, 1.0), ratio=(0.98, 1.02)), + transforms.CenterCrop(self.image_size)), + transforms.ToTensor(), + transforms.Lambda(expand_greyscale(transparent)) + ]) + + def __len__(self): + return len(self.paths) + + def __getitem__(self, index): + path = self.paths[index] + img = Image.open(path) + img = self.transform(img) + return {'LQ': img, 'GT': img, 'GT_path': str(path)} diff --git a/codes/models/archs/stylegan2.py b/codes/models/archs/stylegan2.py new file mode 100644 index 00000000..ec0a2804 --- /dev/null +++ b/codes/models/archs/stylegan2.py @@ -0,0 +1,651 @@ +import math +import multiprocessing +from contextlib import contextmanager, ExitStack +from functools import partial +from math import log2 +from random import random + +import torch +import torch.nn.functional as F +from kornia.filters import filter2D +from linear_attention_transformer import ImageLinearAttention +from torch import nn +from torch.autograd import grad as torch_grad +from vector_quantize_pytorch import VectorQuantize + +from utils.util import checkpoint + +try: + from apex import amp + + APEX_AVAILABLE = True +except: + APEX_AVAILABLE = False + +assert torch.cuda.is_available(), 'You need to have an Nvidia GPU with CUDA installed.' + +num_cores = multiprocessing.cpu_count() + +# constants + +EPS = 1e-8 +CALC_FID_NUM_IMAGES = 12800 + + +# helper classes + +def DiffAugment(x, types=[]): + for p in types: + for f in AUGMENT_FNS[p]: + x = f(x) + return x.contiguous() + +def rand_brightness(x): + x = x + (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5) + return x + +def rand_saturation(x): + x_mean = x.mean(dim=1, keepdim=True) + x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2) + x_mean + return x + +def rand_contrast(x): + x_mean = x.mean(dim=[1, 2, 3], keepdim=True) + x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5) + x_mean + return x + +def rand_translation(x, ratio=0.125): + shift_x, shift_y = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5) + translation_x = torch.randint(-shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device) + translation_y = torch.randint(-shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device) + grid_batch, grid_x, grid_y = torch.meshgrid( + torch.arange(x.size(0), dtype=torch.long, device=x.device), + torch.arange(x.size(2), dtype=torch.long, device=x.device), + torch.arange(x.size(3), dtype=torch.long, device=x.device), + ) + grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1) + grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1) + x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0]) + x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x, grid_y].permute(0, 3, 1, 2) + return x + +def rand_cutout(x, ratio=0.5): + cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5) + offset_x = torch.randint(0, x.size(2) + (1 - cutout_size[0] % 2), size=[x.size(0), 1, 1], device=x.device) + offset_y = torch.randint(0, x.size(3) + (1 - cutout_size[1] % 2), size=[x.size(0), 1, 1], device=x.device) + grid_batch, grid_x, grid_y = torch.meshgrid( + torch.arange(x.size(0), dtype=torch.long, device=x.device), + torch.arange(cutout_size[0], dtype=torch.long, device=x.device), + torch.arange(cutout_size[1], dtype=torch.long, device=x.device), + ) + grid_x = torch.clamp(grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1) + grid_y = torch.clamp(grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1) + mask = torch.ones(x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device) + mask[grid_batch, grid_x, grid_y] = 0 + x = x * mask.unsqueeze(1) + return x + +AUGMENT_FNS = { + 'color': [rand_brightness, rand_saturation, rand_contrast], + 'translation': [rand_translation], + 'cutout': [rand_cutout], +} + +class NanException(Exception): + pass + + +class EMA(): + def __init__(self, beta): + super().__init__() + self.beta = beta + + def update_average(self, old, new): + if not exists(old): + return new + return old * self.beta + (1 - self.beta) * new + + +class Flatten(nn.Module): + def forward(self, x): + return x.reshape(x.shape[0], -1) + + +class Residual(nn.Module): + def __init__(self, fn): + super().__init__() + self.fn = fn + + def forward(self, x): + return self.fn(x) + x + + +class Rezero(nn.Module): + def __init__(self, fn): + super().__init__() + self.fn = fn + self.g = nn.Parameter(torch.zeros(1)) + + def forward(self, x): + return self.fn(x) * self.g + + +class PermuteToFrom(nn.Module): + def __init__(self, fn): + super().__init__() + self.fn = fn + + def forward(self, x): + x = x.permute(0, 2, 3, 1) + out, loss = self.fn(x) + out = out.permute(0, 3, 1, 2) + return out, loss + + +class Blur(nn.Module): + def __init__(self): + super().__init__() + f = torch.Tensor([1, 2, 1]) + self.register_buffer('f', f) + + def forward(self, x): + f = self.f + f = f[None, None, :] * f[None, :, None] + return filter2D(x, f, normalized=True) + + +# one layer of self-attention and feedforward, for images + +attn_and_ff = lambda chan: nn.Sequential(*[ + Residual(Rezero(ImageLinearAttention(chan, norm_queries=True))), + Residual(Rezero(nn.Sequential(nn.Conv2d(chan, chan * 2, 1), leaky_relu(), nn.Conv2d(chan * 2, chan, 1)))) +]) + + +# helpers + +def exists(val): + return val is not None + + +@contextmanager +def null_context(): + yield + + +def combine_contexts(contexts): + @contextmanager + def multi_contexts(): + with ExitStack() as stack: + yield [stack.enter_context(ctx()) for ctx in contexts] + + return multi_contexts + + +def default(value, d): + return value if exists(value) else d + + +def cycle(iterable): + while True: + for i in iterable: + yield i + + +def cast_list(el): + return el if isinstance(el, list) else [el] + + +def is_empty(t): + if isinstance(t, torch.Tensor): + return t.nelement() == 0 + return not exists(t) + + +def raise_if_nan(t): + if torch.isnan(t): + raise NanException + + +def gradient_accumulate_contexts(gradient_accumulate_every, is_ddp, ddps): + if is_ddp: + num_no_syncs = gradient_accumulate_every - 1 + head = [combine_contexts(map(lambda ddp: ddp.no_sync, ddps))] * num_no_syncs + tail = [null_context] + contexts = head + tail + else: + contexts = [null_context] * gradient_accumulate_every + + for context in contexts: + with context(): + yield + + +def loss_backwards(fp16, loss, optimizer, loss_id, **kwargs): + if fp16: + with amp.scale_loss(loss, optimizer, loss_id) as scaled_loss: + scaled_loss.backward(**kwargs) + else: + loss.backward(**kwargs) + + +def gradient_penalty(images, output, weight=10): + batch_size = images.shape[0] + gradients = torch_grad(outputs=output, inputs=images, + grad_outputs=torch.ones(output.size(), device=images.device), + create_graph=True, retain_graph=True, only_inputs=True)[0] + + gradients = gradients.reshape(batch_size, -1) + return weight * ((gradients.norm(2, dim=1) - 1) ** 2).mean() + + +def calc_pl_lengths(styles, images): + device = images.device + num_pixels = images.shape[2] * images.shape[3] + pl_noise = torch.randn(images.shape, device=device) / math.sqrt(num_pixels) + outputs = (images * pl_noise).sum() + + pl_grads = torch_grad(outputs=outputs, inputs=styles, + grad_outputs=torch.ones(outputs.shape, device=device), + create_graph=True, retain_graph=True, only_inputs=True)[0] + + return (pl_grads ** 2).sum(dim=2).mean(dim=1).sqrt() + + +def image_noise(n, im_size, device): + return torch.FloatTensor(n, im_size, im_size, 1).uniform_(0., 1.).cuda(device) + + +def leaky_relu(p=0.2): + return nn.LeakyReLU(p, inplace=True) + + +def evaluate_in_chunks(max_batch_size, model, *args): + split_args = list(zip(*list(map(lambda x: x.split(max_batch_size, dim=0), args)))) + chunked_outputs = [model(*i) for i in split_args] + if len(chunked_outputs) == 1: + return chunked_outputs[0] + return torch.cat(chunked_outputs, dim=0) + + +def set_requires_grad(model, bool): + for p in model.parameters(): + p.requires_grad = bool + + +def slerp(val, low, high): + low_norm = low / torch.norm(low, dim=1, keepdim=True) + high_norm = high / torch.norm(high, dim=1, keepdim=True) + omega = torch.acos((low_norm * high_norm).sum(1)) + so = torch.sin(omega) + res = (torch.sin((1.0 - val) * omega) / so).unsqueeze(1) * low + (torch.sin(val * omega) / so).unsqueeze(1) * high + return res + +# augmentations + +def random_hflip(tensor, prob): + if prob > random(): + return tensor + return torch.flip(tensor, dims=(3,)) + + +class StyleGan2Augmentor(nn.Module): + def __init__(self, D, image_size, types, prob): + super().__init__() + self.D = D + self.prob = prob + self.types = types + + def forward(self, images, detach=False): + if random() < self.prob: + images = random_hflip(images, prob=0.5) + images = DiffAugment(images, types=self.types) + + if detach: + images = images.detach() + + return self.D(images) + + +# stylegan2 classes + +class EqualLinear(nn.Module): + def __init__(self, in_dim, out_dim, lr_mul=1, bias=True): + super().__init__() + self.weight = nn.Parameter(torch.randn(out_dim, in_dim)) + if bias: + self.bias = nn.Parameter(torch.zeros(out_dim)) + + self.lr_mul = lr_mul + + def forward(self, input): + return F.linear(input, self.weight * self.lr_mul, bias=self.bias * self.lr_mul) + + +class StyleVectorizer(nn.Module): + def __init__(self, emb, depth, lr_mul=0.1): + super().__init__() + + layers = [] + for i in range(depth): + layers.extend([EqualLinear(emb, emb, lr_mul), leaky_relu()]) + + self.net = nn.Sequential(*layers) + + def forward(self, x): + x = F.normalize(x, dim=1) + return self.net(x) + + +class RGBBlock(nn.Module): + def __init__(self, latent_dim, input_channel, upsample, rgba=False): + super().__init__() + self.input_channel = input_channel + self.to_style = nn.Linear(latent_dim, input_channel) + + out_filters = 3 if not rgba else 4 + self.conv = Conv2DMod(input_channel, out_filters, 1, demod=False) + + self.upsample = nn.Sequential( + nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False), + Blur() + ) if upsample else None + + def forward(self, x, prev_rgb, istyle): + b, c, h, w = x.shape + style = self.to_style(istyle) + x = self.conv(x, style) + + if exists(prev_rgb): + x = x + prev_rgb + + if exists(self.upsample): + x = self.upsample(x) + + return x + + +class Conv2DMod(nn.Module): + def __init__(self, in_chan, out_chan, kernel, demod=True, stride=1, dilation=1, **kwargs): + super().__init__() + self.filters = out_chan + self.demod = demod + self.kernel = kernel + self.stride = stride + self.dilation = dilation + self.weight = nn.Parameter(torch.randn((out_chan, in_chan, kernel, kernel))) + nn.init.kaiming_normal_(self.weight, a=0, mode='fan_in', nonlinearity='leaky_relu') + + def _get_same_padding(self, size, kernel, dilation, stride): + return ((size - 1) * (stride - 1) + dilation * (kernel - 1)) // 2 + + def forward(self, x, y): + b, c, h, w = x.shape + + w1 = y[:, None, :, None, None] + w2 = self.weight[None, :, :, :, :] + weights = w2 * (w1 + 1) + + if self.demod: + d = torch.rsqrt((weights ** 2).sum(dim=(2, 3, 4), keepdim=True) + EPS) + weights = weights * d + + x = x.reshape(1, -1, h, w) + + _, _, *ws = weights.shape + weights = weights.reshape(b * self.filters, *ws) + + padding = self._get_same_padding(h, self.kernel, self.dilation, self.stride) + x = F.conv2d(x, weights, padding=padding, groups=b) + + x = x.reshape(-1, self.filters, h, w) + return x + + +class GeneratorBlock(nn.Module): + def __init__(self, latent_dim, input_channels, filters, upsample=True, upsample_rgb=True, rgba=False): + super().__init__() + self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False) if upsample else None + + self.to_style1 = nn.Linear(latent_dim, input_channels) + self.to_noise1 = nn.Linear(1, filters) + self.conv1 = Conv2DMod(input_channels, filters, 3) + + self.to_style2 = nn.Linear(latent_dim, filters) + self.to_noise2 = nn.Linear(1, filters) + self.conv2 = Conv2DMod(filters, filters, 3) + + self.activation = leaky_relu() + self.to_rgb = RGBBlock(latent_dim, filters, upsample_rgb, rgba) + + def forward(self, x, prev_rgb, istyle, inoise): + if exists(self.upsample): + x = self.upsample(x) + + inoise = inoise[:, :x.shape[2], :x.shape[3], :] + noise1 = self.to_noise1(inoise).permute((0, 3, 2, 1)) + noise2 = self.to_noise2(inoise).permute((0, 3, 2, 1)) + + style1 = self.to_style1(istyle) + x = self.conv1(x, style1) + x = self.activation(x + noise1) + + style2 = self.to_style2(istyle) + x = self.conv2(x, style2) + x = self.activation(x + noise2) + + rgb = self.to_rgb(x, prev_rgb, istyle) + return x, rgb + + +class DiscriminatorBlock(nn.Module): + def __init__(self, input_channels, filters, downsample=True): + super().__init__() + self.conv_res = nn.Conv2d(input_channels, filters, 1, stride=(2 if downsample else 1)) + + self.net = nn.Sequential( + nn.Conv2d(input_channels, filters, 3, padding=1), + leaky_relu(), + nn.Conv2d(filters, filters, 3, padding=1), + leaky_relu() + ) + + self.downsample = nn.Sequential( + Blur(), + nn.Conv2d(filters, filters, 3, padding=1, stride=2) + ) if downsample else None + + def forward(self, x): + res = self.conv_res(x) + x = self.net(x) + if exists(self.downsample): + x = self.downsample(x) + x = (x + res) * (1 / math.sqrt(2)) + return x + + +class Generator(nn.Module): + def __init__(self, image_size, latent_dim, network_capacity=16, transparent=False, attn_layers=[], no_const=False, + fmap_max=512): + super().__init__() + self.image_size = image_size + self.latent_dim = latent_dim + self.num_layers = int(log2(image_size) - 1) + + filters = [network_capacity * (2 ** (i + 1)) for i in range(self.num_layers)][::-1] + + set_fmap_max = partial(min, fmap_max) + filters = list(map(set_fmap_max, filters)) + init_channels = filters[0] + filters = [init_channels, *filters] + + in_out_pairs = zip(filters[:-1], filters[1:]) + self.no_const = no_const + + if no_const: + self.to_initial_block = nn.ConvTranspose2d(latent_dim, init_channels, 4, 1, 0, bias=False) + else: + self.initial_block = nn.Parameter(torch.randn((1, init_channels, 4, 4))) + + self.initial_conv = nn.Conv2d(filters[0], filters[0], 3, padding=1) + self.blocks = nn.ModuleList([]) + self.attns = nn.ModuleList([]) + + for ind, (in_chan, out_chan) in enumerate(in_out_pairs): + not_first = ind != 0 + not_last = ind != (self.num_layers - 1) + num_layer = self.num_layers - ind + + attn_fn = attn_and_ff(in_chan) if num_layer in attn_layers else None + + self.attns.append(attn_fn) + + block = GeneratorBlock( + latent_dim, + in_chan, + out_chan, + upsample=not_first, + upsample_rgb=not_last, + rgba=transparent + ) + self.blocks.append(block) + + def forward(self, styles, input_noise): + batch_size = styles.shape[0] + image_size = self.image_size + + if self.no_const: + avg_style = styles.mean(dim=1)[:, :, None, None] + x = self.to_initial_block(avg_style) + else: + x = self.initial_block.expand(batch_size, -1, -1, -1) + + rgb = None + styles = styles.transpose(0, 1) + x = self.initial_conv(x) + + for style, block, attn in zip(styles, self.blocks, self.attns): + if exists(attn): + x = attn(x) + x, rgb = checkpoint(block, x, rgb, style, input_noise) + + return rgb + + +# Wrapper that combines style vectorizer with the actual generator. +class StyleGan2GeneratorWithLatent(nn.Module): + def __init__(self, image_size, latent_dim=512, style_depth=8, lr_mlp=.1, network_capacity=16, transparent=False, attn_layers=[], no_const=False, fmap_max=512): + super().__init__() + self.vectorizer = StyleVectorizer(latent_dim, style_depth, lr_mul=lr_mlp) + self.gen = Generator(image_size, latent_dim, network_capacity, transparent, attn_layers, no_const, fmap_max) + self.mixed_prob = .9 + self._init_weights() + + def noise(self, n, latent_dim, device): + return torch.randn(n, latent_dim).cuda(device) + + # To use per the stylegan paper, input should be uniform noise. This gen takes it in as a normal "image" format: + # b,f,h,w. + def forward(self, x): + b, f, h, w = x.shape + style = self.noise(b*2, self.gen.latent_dim, x.device) + w = self.vectorizer(style) + + # Randomly distribute styles across layers + w_styles = w[:,None,:].expand(-1, self.gen.num_layers, -1).clone() + for j in range(b): + cutoff = int(torch.rand(()).numpy() * self.gen.num_layers) + if cutoff == self.gen.num_layers or random() > self.mixed_prob: + w_styles[j] = w_styles[j*2] + else: + w_styles[j, :cutoff] = w_styles[j*2, :cutoff] + w_styles[j, cutoff:] = w_styles[j*2+1, cutoff:] + w_styles = w_styles[:b] + + # The underlying model expects the noise as b,h,w,1. Make it so. + return self.gen(w_styles, x[:,0,:,:].unsqueeze(dim=3)), w_styles + + def _init_weights(self): + for m in self.modules(): + if type(m) in {nn.Conv2d, nn.Linear}: + nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu') + + for block in self.gen.blocks: + nn.init.zeros_(block.to_noise1.weight) + nn.init.zeros_(block.to_noise2.weight) + nn.init.zeros_(block.to_noise1.bias) + nn.init.zeros_(block.to_noise2.bias) + + +class StyleGan2Discriminator(nn.Module): + def __init__(self, image_size, network_capacity=16, fq_layers=[], fq_dict_size=256, attn_layers=[], + transparent=False, fmap_max=512): + super().__init__() + num_layers = int(log2(image_size) - 1) + num_init_filters = 3 if not transparent else 4 + + blocks = [] + filters = [num_init_filters] + [(64) * (2 ** i) for i in range(num_layers + 1)] + + set_fmap_max = partial(min, fmap_max) + filters = list(map(set_fmap_max, filters)) + chan_in_out = list(zip(filters[:-1], filters[1:])) + + blocks = [] + attn_blocks = [] + quantize_blocks = [] + + for ind, (in_chan, out_chan) in enumerate(chan_in_out): + num_layer = ind + 1 + is_not_last = ind != (len(chan_in_out) - 1) + + block = DiscriminatorBlock(in_chan, out_chan, downsample=is_not_last) + blocks.append(block) + + attn_fn = attn_and_ff(out_chan) if num_layer in attn_layers else None + + attn_blocks.append(attn_fn) + + quantize_fn = PermuteToFrom(VectorQuantize(out_chan, fq_dict_size)) if num_layer in fq_layers else None + quantize_blocks.append(quantize_fn) + + self.blocks = nn.ModuleList(blocks) + self.attn_blocks = nn.ModuleList(attn_blocks) + self.quantize_blocks = nn.ModuleList(quantize_blocks) + + chan_last = filters[-1] + latent_dim = 2 * 2 * chan_last + + self.final_conv = nn.Conv2d(chan_last, chan_last, 3, padding=1) + self.flatten = Flatten() + self.to_logit = nn.Linear(latent_dim, 1) + + self._init_weights() + + def forward(self, x): + b, *_ = x.shape + + quantize_loss = torch.zeros(1).to(x) + + for (block, attn_block, q_block) in zip(self.blocks, self.attn_blocks, self.quantize_blocks): + x = checkpoint(block, x) + + if exists(attn_block): + x = attn_block(x) + + if exists(q_block): + x, _, loss = q_block(x) + quantize_loss += loss + + x = self.final_conv(x) + x = self.flatten(x) + x = self.to_logit(x) + if exists(q_block): + return x.squeeze(), quantize_loss + else: + return x.squeeze() + + def _init_weights(self): + for m in self.modules(): + if type(m) in {nn.Conv2d, nn.Linear}: + nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu') \ No newline at end of file diff --git a/codes/models/networks.py b/codes/models/networks.py index f1ceff7a..4c6a26ea 100644 --- a/codes/models/networks.py +++ b/codes/models/networks.py @@ -22,6 +22,7 @@ from models.archs.stylegan.Discriminator_StyleGAN import StyleGanDiscriminator from models.archs.pyramid_arch import BasicResamplingFlowNet from models.archs.rrdb_with_adain_latent import AdaRRDBNet, LinearLatentEstimator from models.archs.rrdb_with_latent import LatentEstimator, RRDBNetWithLatent, LatentEstimator2 +from models.archs.stylegan2 import StyleGan2GeneratorWithLatent, StyleGan2Discriminator, StyleGan2Augmentor from models.archs.teco_resgen import TecoGen logger = logging.getLogger('base') @@ -131,6 +132,9 @@ def define_G(opt, net_key='network_G', scale=None): netG = LatentEstimator(in_nc=3, nf=opt_net['nf'], overwrite_levels=overwrite) elif which_model == "linear_latent_estimator": netG = LinearLatentEstimator(in_nc=3, nf=opt_net['nf']) + elif which_model == 'stylegan2': + netG = StyleGan2GeneratorWithLatent(image_size=opt_net['image_size'], latent_dim=opt_net['latent_dim'], + style_depth=opt_net['style_depth']) else: raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model)) return netG @@ -189,6 +193,9 @@ def define_D_net(opt_net, img_sz=None, wrap=False): netD = SRGAN_arch.PsnrApproximator(nf=opt_net['nf'], input_img_factor=img_sz / 128) elif which_model == "pyramid_disc": netD = SRGAN_arch.PyramidDiscriminator(in_nc=3, nf=opt_net['nf']) + elif which_model == "stylegan2_discriminator": + disc = StyleGan2Discriminator(image_size=opt_net['image_size']) + netD = StyleGan2Augmentor(disc, opt_net['image_size'], types=opt_net['augmentation_types'], prob=opt_net['augmentation_probability']) else: raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model)) return netD diff --git a/codes/models/steps/injectors.py b/codes/models/steps/injectors.py index c5e592fa..272e7ec9 100644 --- a/codes/models/steps/injectors.py +++ b/codes/models/steps/injectors.py @@ -147,6 +147,7 @@ class ScheduledScalarInjector(Injector): class AddNoiseInjector(Injector): def __init__(self, opt, env): super(AddNoiseInjector, self).__init__(opt, env) + self.mode = opt['mode'] if 'mode' in opt.keys() else 'normal' def forward(self, state): # Scale can be a fixed float, or a state key (e.g. from ScheduledScalarInjector). @@ -155,7 +156,11 @@ class AddNoiseInjector(Injector): else: scale = self.opt['scale'] - noise = torch.randn_like(state[self.opt['in']], device=self.env['device']) * scale + ref = state[self.opt['in']] + if self.mode == 'normal': + noise = torch.randn_like(ref) * scale + elif self.mode == 'uniform': + noise = torch.FloatTensor(ref.shape).uniform_(0.0, scale).to(ref.device) return {self.opt['out']: state[self.opt['in']] + noise} diff --git a/codes/models/steps/losses.py b/codes/models/steps/losses.py index 4ef123b3..558e18f5 100644 --- a/codes/models/steps/losses.py +++ b/codes/models/steps/losses.py @@ -6,7 +6,8 @@ from models.networks import define_F from models.loss import GANLoss import random import functools -import torchvision +import torch.nn.functional as F +import numpy as np def create_loss(opt_loss, env): @@ -36,6 +37,10 @@ def create_loss(opt_loss, env): return RecurrentLoss(opt_loss, env) elif type == 'for_element': return ForElementLoss(opt_loss, env) + elif type == 'stylegan2_divergence': + return StyleGan2DivergenceLoss(opt_loss, env) + elif type == 'stylegan2_pathlen': + return StyleGan2PathLengthLoss(opt_loss, env) else: raise NotImplementedError @@ -482,3 +487,62 @@ class ForElementLoss(ConfigurableLoss): def clear_metrics(self): self.loss.clear_metrics() + + +class StyleGan2DivergenceLoss(ConfigurableLoss): + def __init__(self, opt, env): + super().__init__(opt, env) + self.real = opt['real'] + self.fake = opt['fake'] + self.discriminator = opt['discriminator'] + self.for_gen = opt['gen_loss'] + self.gp_frequency = opt['gradient_penalty_frequency'] + + def forward(self, net, state): + D = self.env['discriminators'][self.discriminator] + fake = D(state[self.fake]) + if self.for_gen: + return fake.mean() + else: + real_input = state[self.real].requires_grad_() # <-- Needed to compute gradients on the input. + real = D(real_input) + divergence_loss = (F.relu(1 + real) + F.relu(1 - fake)).mean() + + gp = 0 + if self.env['step'] % self.gp_frequency == 0: + # Apply gradient penalty. TODO: migrate this elsewhere. + from models.archs.stylegan2 import gradient_penalty + gp = gradient_penalty(real_input, real) + self.last_gp_loss = gp.clone().detach().item() + self.metrics.append(("gradient_penalty", gp)) + + real_input.requires_grad_(requires_grad=False) + return divergence_loss + gp + + +class StyleGan2PathLengthLoss(ConfigurableLoss): + def __init__(self, opt, env): + super().__init__(opt, env) + self.w_styles = opt['w_styles'] + self.gen = opt['gen'] + self.pl_mean = None + from models.archs.stylegan2 import EMA + self.pl_length_ma = EMA(.99) + + def forward(self, net, state): + w_styles = state[self.w_styles] + gen = state[self.gen] + from models.archs.stylegan2 import calc_pl_lengths + pl_lengths = calc_pl_lengths(w_styles, gen) + avg_pl_length = np.mean(pl_lengths.detach().cpu().numpy()) + + from models.archs.stylegan2 import is_empty + if not is_empty(self.pl_mean): + pl_loss = ((pl_lengths - self.pl_mean) ** 2).mean() + if not torch.isnan(pl_loss): + return pl_loss + else: + print("Path length loss returned NaN!") + + self.pl_mean = self.pl_length_ma.update_average(self.pl_mean, avg_pl_length) + return 0 diff --git a/codes/models/steps/steps.py b/codes/models/steps/steps.py index 1eb02939..483d9592 100644 --- a/codes/models/steps/steps.py +++ b/codes/models/steps/steps.py @@ -152,7 +152,8 @@ class ConfigurableStep(Module): # Some losses only activate after a set number of steps. For example, proto-discriminator losses can # be very disruptive to a generator. if 'after' in loss.opt.keys() and loss.opt['after'] > self.env['step'] or \ - 'before' in loss.opt.keys() and self.env['step'] > loss.opt['before']: + 'before' in loss.opt.keys() and self.env['step'] > loss.opt['before'] or \ + 'every' in loss.opt.keys() and self.env['step'] % loss.opt['every'] != 0: continue l = loss(self.training_net, local_state) total_loss += l * self.weights[loss_name]