diff --git a/codes/models/archs/stylegan/stylegan2_lucidrains_conformed.py b/codes/models/archs/stylegan/stylegan2_lucidrains_conformed.py new file mode 100644 index 00000000..23d33912 --- /dev/null +++ b/codes/models/archs/stylegan/stylegan2_lucidrains_conformed.py @@ -0,0 +1,847 @@ +import math +import multiprocessing +from contextlib import contextmanager, ExitStack +from functools import partial +from math import log2 +from random import random + +import torch +import torch.nn.functional as F +import trainer.losses as L +import numpy as np + +from kornia.filters import filter2D +from linear_attention_transformer import ImageLinearAttention +from torch import nn +from torch.autograd import grad as torch_grad +from vector_quantize_pytorch import VectorQuantize + +from utils.util import checkpoint + +try: + from apex import amp + + APEX_AVAILABLE = True +except: + APEX_AVAILABLE = False + +assert torch.cuda.is_available(), 'You need to have an Nvidia GPU with CUDA installed.' + +num_cores = multiprocessing.cpu_count() + +# constants + +EPS = 1e-8 +CALC_FID_NUM_IMAGES = 12800 + + +# helper classes + +def DiffAugment(x, types=[]): + for p in types: + for f in AUGMENT_FNS[p]: + x = f(x) + return x.contiguous() + +def rand_brightness(x): + x = x + (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5) + return x + +def rand_saturation(x): + x_mean = x.mean(dim=1, keepdim=True) + x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2) + x_mean + return x + +def rand_contrast(x): + x_mean = x.mean(dim=[1, 2, 3], keepdim=True) + x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5) + x_mean + return x + +def rand_translation(x, ratio=0.125): + shift_x, shift_y = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5) + translation_x = torch.randint(-shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device) + translation_y = torch.randint(-shift_y, shift_y + 1, size=[x.size(0), 1, 
1], device=x.device) + grid_batch, grid_x, grid_y = torch.meshgrid( + torch.arange(x.size(0), dtype=torch.long, device=x.device), + torch.arange(x.size(2), dtype=torch.long, device=x.device), + torch.arange(x.size(3), dtype=torch.long, device=x.device), + ) + grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1) + grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1) + x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0]) + x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x, grid_y].permute(0, 3, 1, 2) + return x + +def rand_cutout(x, ratio=0.5): + cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5) + offset_x = torch.randint(0, x.size(2) + (1 - cutout_size[0] % 2), size=[x.size(0), 1, 1], device=x.device) + offset_y = torch.randint(0, x.size(3) + (1 - cutout_size[1] % 2), size=[x.size(0), 1, 1], device=x.device) + grid_batch, grid_x, grid_y = torch.meshgrid( + torch.arange(x.size(0), dtype=torch.long, device=x.device), + torch.arange(cutout_size[0], dtype=torch.long, device=x.device), + torch.arange(cutout_size[1], dtype=torch.long, device=x.device), + ) + grid_x = torch.clamp(grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1) + grid_y = torch.clamp(grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1) + mask = torch.ones(x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device) + mask[grid_batch, grid_x, grid_y] = 0 + x = x * mask.unsqueeze(1) + return x + +AUGMENT_FNS = { + 'color': [rand_brightness, rand_saturation, rand_contrast], + 'translation': [rand_translation], + 'cutout': [rand_cutout], +} + +class NanException(Exception): + pass + + +class EMA(): + def __init__(self, beta): + super().__init__() + self.beta = beta + + def update_average(self, old, new): + if not exists(old): + return new + return old * self.beta + (1 - self.beta) * new + + +class Flatten(nn.Module): + def forward(self, x): + return x.reshape(x.shape[0], -1) + + +class Residual(nn.Module): + def 
def combine_contexts(contexts):
    """Build a reusable factory that enters several context managers at once.

    Args:
        contexts: Iterable of zero-argument callables, each returning a
            context manager.

    Returns:
        A zero-argument callable. Using its result in a ``with`` statement
        enters every context in order and yields the list of values they
        produced on entry; all of them are exited when the block ends.
    """
    # Snapshot the iterable up front. gradient_accumulate_contexts hands in a
    # one-shot `map` object and then invokes the returned factory several
    # times; iterating lazily would enter the contexts on the first use only
    # and silently enter nothing on every later use.
    contexts = tuple(contexts)

    @contextmanager
    def multi_contexts():
        with ExitStack() as stack:
            yield [stack.enter_context(ctx()) for ctx in contexts]

    return multi_contexts
def biased_leaky_relu(input, bias=None, negative_slope=0.2, scale=2 ** 0.5):
    """Add an optional per-channel bias, apply a leaky ReLU, then a constant gain.

    Args:
        input: Tensor whose dimension 1 is the channel dimension.
        bias: Optional 1-D tensor with one entry per channel. It is reshaped
            to ``(1, C, 1, ..., 1)`` so it broadcasts over every other
            dimension of ``input``.
        negative_slope: Slope of the leaky ReLU for negative values.
        scale: Constant multiplied into the activated output.

    Returns:
        ``leaky_relu(input + bias, negative_slope) * scale``; the bias term is
        omitted when ``bias`` is None.
    """
    if bias is not None:
        # One singleton axis per trailing dimension of `input` after the channel axis.
        rest_dim = [1] * (input.ndim - bias.ndim - 1)
        input = input + bias.view(1, bias.shape[0], *rest_dim)

    # Use the caller's slope on both paths; the bias-free path previously
    # hard-coded 0.2 and ignored `negative_slope`.
    return F.leaky_relu(input, negative_slope=negative_slope) * scale
class EqualLinear(nn.Module):
    """Linear layer with a learning-rate multiplier and an optional leaky-ReLU.

    The stored weight and bias are multiplied by ``lr_mul`` on every forward
    pass rather than at initialisation.

    Args:
        in_dim: Size of the last dimension of the input.
        out_dim: Number of output features.
        lr_mul: Multiplier applied to weight and bias at call time.
        bias: Whether to learn an additive bias.
        activation: When true, the (scaled) bias is added and a leaky ReLU is
            applied through ``biased_leaky_relu``; when false the layer is a
            plain affine map.
    """

    def __init__(self, in_dim, out_dim, lr_mul=1, bias=True, activation=False):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_dim, in_dim))
        # Always define the attribute so forward() also works with bias=False
        # (previously it was only created when bias=True).
        self.bias = nn.Parameter(torch.zeros(out_dim)) if bias else None

        self.lr_mul = lr_mul
        self.activation = activation

    def forward(self, input):
        bias = None if self.bias is None else self.bias * self.lr_mul

        if not self.activation:
            return F.linear(input, self.weight * self.lr_mul, bias=bias)

        # The two branches used to be swapped: activation=True produced a bare
        # affine map and activation=False applied the leaky ReLU.
        out = F.linear(input, self.weight * self.lr_mul)
        return biased_leaky_relu(out, bias)
kernel_size=1, norm=False, activation=False, bias=True, weight_init_factor=0) + self.norm = nn.InstanceNorm2d(in_channel) + + def forward(self, input, style): + gamma = self.style2scale(style) + beta = self.style2bias(style) + out = self.norm(input) + out = gamma * out + beta + return out + + +class NoiseInjection(nn.Module): + def __init__(self, channel): + super().__init__() + self.weight = nn.Parameter(torch.zeros(1, channel, 1, 1)) + + def forward(self, image, noise): + return image + self.weight * noise + + +class EqualLR: + def __init__(self, name): + self.name = name + + def compute_weight(self, module): + weight = getattr(module, self.name + '_orig') + fan_in = weight.data.size(1) * weight.data[0][0].numel() + + return weight * math.sqrt(2 / fan_in) + + @staticmethod + def apply(module, name): + fn = EqualLR(name) + + weight = getattr(module, name) + del module._parameters[name] + module.register_parameter(name + '_orig', nn.Parameter(weight.data)) + module.register_forward_pre_hook(fn) + + return fn + + def __call__(self, module, input): + weight = self.compute_weight(module) + setattr(module, self.name, weight) + + +def equal_lr(module, name='weight'): + EqualLR.apply(module, name) + + return module + + +class EqualConv2d(nn.Module): + def __init__(self, *args, **kwargs): + super().__init__() + + conv = nn.Conv2d(*args, **kwargs) + conv.weight.data.normal_() + conv.bias.data.zero_() + self.conv = equal_lr(conv) + + def forward(self, input): + return self.conv(input) + + +class Conv2DMod(nn.Module): + def __init__(self, in_chan, out_chan, kernel, demod=True, stride=1, dilation=1, **kwargs): + super().__init__() + self.filters = out_chan + self.demod = demod + self.kernel = kernel + self.stride = stride + self.dilation = dilation + self.weight = nn.Parameter(torch.randn((out_chan, in_chan, kernel, kernel))) + nn.init.kaiming_normal_(self.weight, a=0, mode='fan_in', nonlinearity='leaky_relu') + + def _get_same_padding(self, size, kernel, dilation, stride): + 
return ((size - 1) * (stride - 1) + dilation * (kernel - 1)) // 2 + + def forward(self, x, y): + b, c, h, w = x.shape + + w1 = y[:, None, :, None, None] + w2 = self.weight[None, :, :, :, :] + weights = w2 * (w1 + 1) + + if self.demod: + d = torch.rsqrt((weights ** 2).sum(dim=(2, 3, 4), keepdim=True) + EPS) + weights = weights * d + + x = x.reshape(1, -1, h, w) + + _, _, *ws = weights.shape + weights = weights.reshape(b * self.filters, *ws) + + padding = self._get_same_padding(h, self.kernel, self.dilation, self.stride) + x = F.conv2d(x, weights, padding=padding, groups=b) + + x = x.reshape(-1, self.filters, h, w) + return x + + +class GeneratorBlock(nn.Module): + def __init__(self, latent_dim, input_channels, filters, upsample=True, upsample_rgb=True, rgba=False, initial_block=False): + super().__init__() + self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False) if upsample else None + + self.to_style1 = nn.Linear(latent_dim, input_channels) + self.noise1_scale = nn.Parameter(torch.full((1,), fill_value=1e-5)) + self.conv1 = Conv2DMod(input_channels, filters, 3) + self.activation1 = BiasedLeakyReLU(filters) + + self.initial_block = initial_block + if not initial_block: + self.to_style2 = nn.Linear(latent_dim, filters) + self.noise2_scale = nn.Parameter(torch.full((1,), fill_value=1e-5)) + self.conv2 = Conv2DMod(filters, filters, 3) + self.activation2 = BiasedLeakyReLU(filters) + + self.to_rgb = RGBBlock(latent_dim, filters, upsample_rgb, rgba) + + def forward(self, x, prev_rgb, istyle, inoise1=None, inoise2=None): + if exists(self.upsample): + x = self.upsample(x) + if inoise1 is None: + b, c, h, w = x.shape + inoise1 = torch.randn((b,1,h,w), device=x.device) + inoise2 = torch.randn((b,1,h,w), device=x.device) # Assume that both are None if one is None. 
+ + noise1 = inoise1 * self.noise1_scale + style1 = self.to_style1(istyle) + x = self.conv1(x, style1) + x = self.activation1(x + noise1) + + if not self.initial_block: + noise2 = inoise2 * self.noise2_scale + style2 = self.to_style2(istyle) + x = self.conv2(x, style2) + x = self.activation2(x + noise2) + + rgb = self.to_rgb(x, prev_rgb, istyle) + return x, rgb + + +class Generator(nn.Module): + def __init__(self, image_size, latent_dim, network_capacity=16, transparent=False, attn_layers=[], no_const=False, + fmap_max=512): + super().__init__() + self.image_size = image_size + self.latent_dim = latent_dim + self.num_layers = int(log2(image_size) - 1) + + filters = [network_capacity * (2 ** (i + 1)) for i in range(self.num_layers)][::-1] + + set_fmap_max = partial(min, fmap_max) + filters = list(map(set_fmap_max, filters)) + init_channels = filters[0] + filters = [init_channels, *filters] + + in_out_pairs = zip(filters[:-1], filters[1:]) + self.no_const = no_const + + if no_const: + self.to_initial_block = nn.ConvTranspose2d(latent_dim, init_channels, 4, 1, 0, bias=False) + else: + self.initial_block = nn.Parameter(torch.randn((1, init_channels, 4, 4))) + + self.blocks = nn.ModuleList([]) + self.attns = nn.ModuleList([]) + + for ind, (in_chan, out_chan) in enumerate(in_out_pairs): + not_first = ind != 0 + not_last = ind != (self.num_layers - 1) + num_layer = self.num_layers - ind + + attn_fn = attn_and_ff(in_chan) if num_layer in attn_layers else None + + self.attns.append(attn_fn) + + block_fn = GeneratorBlock + + block = block_fn( + latent_dim, + in_chan, + out_chan, + upsample=not_first, + upsample_rgb=not_last, + rgba=transparent, + initial_block=(ind == 0) + ) + self.blocks.append(block) + + def forward(self, styles, input_noises): + batch_size = styles.shape[0] + + if self.no_const: + avg_style = styles.mean(dim=1)[:, :, None, None] + x = self.to_initial_block(avg_style) + else: + x = self.initial_block.expand(batch_size, -1, -1, -1) + + rgb = None + styles = 
styles.transpose(0, 1) + + n = 0 + for style, block, attn in zip(styles, self.blocks, self.attns): + if exists(attn): + x = checkpoint(attn, x) + x, rgb = checkpoint(block, x, rgb, style, input_noises[n], input_noises[n+1]) + n = 1 if n == 0 else n + 2 # The first block only consumes 1 noise, the rest consume 2. + + return rgb + + +# Wrapper that combines style vectorizer with the actual generator. +class StyleGan2GeneratorWithLatent(nn.Module): + def __init__(self, image_size, latent_dim=512, style_depth=8, lr_mlp=.1, network_capacity=16, transparent=False, + attn_layers=[], no_const=False, fmap_max=512): + super().__init__() + self.vectorizer = StyleVectorizer(latent_dim, style_depth, lr_mul=lr_mlp) + self.gen = Generator(image_size, latent_dim, network_capacity, transparent, attn_layers, no_const, fmap_max) + self.mixed_prob = .9 + self._init_weights() + + + def noise(self, n, latent_dim, device): + return torch.randn(n, latent_dim).cuda(device) + + def noise_list(self, n, layers, latent_dim, device): + return [(self.noise(n, latent_dim, device), layers)] + + def mixed_list(self, n, layers, latent_dim, device): + tt = int(torch.rand(()).numpy() * layers) + return self.noise_list(n, tt, latent_dim, device) + self.noise_list(n, layers - tt, latent_dim, device) + + def latent_to_w(self, style_vectorizer, latent_descr): + return [(style_vectorizer(z), num_layers) for z, num_layers in latent_descr] + + def styles_def_to_tensor(self, styles_def): + return torch.cat([t[:, None, :].expand(-1, n, -1) for t, n in styles_def], dim=1) + + # If provided, 'noise' should be a list of tensors that is fed into each input block. + # b=batch_size. 
class DiscriminatorBlock(nn.Module):
    """Residual discriminator stage: two 3x3 convolutions plus a 1x1 skip path.

    Args:
        input_channels: Channels of the incoming feature map.
        filters: Channels produced by both the main and the skip path.
        downsample: When true, the main path ends in a blur followed by a
            stride-2 convolution and the skip convolution uses stride 2.
    """

    def __init__(self, input_channels, filters, downsample=True):
        super().__init__()
        self.conv_res = nn.Conv2d(input_channels, filters, 1, stride=(2 if downsample else 1))

        # The original called a `leaky_relu()` helper that is not defined in
        # this file (NameError on construction). nn.LeakyReLU is used directly;
        # 0.2 mirrors the default slope of biased_leaky_relu in this file.
        # NOTE(review): confirm 0.2 is the intended slope.
        # The activations hold no parameters, so state-dict keys are unchanged.
        self.net = nn.Sequential(
            nn.Conv2d(input_channels, filters, 3, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(filters, filters, 3, padding=1),
            nn.LeakyReLU(0.2),
        )

        self.downsample = nn.Sequential(
            Blur(),
            nn.Conv2d(filters, filters, 3, padding=1, stride=2)
        ) if downsample else None

    def forward(self, x):
        res = self.conv_res(x)
        x = self.net(x)
        if self.downsample is not None:
            x = self.downsample(x)
        # Scale the sum of the two paths by 1/sqrt(2).
        return (x + res) * (1 / math.sqrt(2))
network_capacity=16, fq_layers=[], fq_dict_size=256, attn_layers=[], + transparent=False, fmap_max=512, input_filters=3): + super().__init__() + num_layers = int(log2(image_size) - 1) + + blocks = [] + filters = [input_filters] + [(64) * (2 ** i) for i in range(num_layers + 1)] + + set_fmap_max = partial(min, fmap_max) + filters = list(map(set_fmap_max, filters)) + chan_in_out = list(zip(filters[:-1], filters[1:])) + + blocks = [] + attn_blocks = [] + quantize_blocks = [] + + for ind, (in_chan, out_chan) in enumerate(chan_in_out): + num_layer = ind + 1 + is_not_last = ind != (len(chan_in_out) - 1) + + block = DiscriminatorBlock(in_chan, out_chan, downsample=is_not_last) + blocks.append(block) + + attn_fn = attn_and_ff(out_chan) if num_layer in attn_layers else None + + attn_blocks.append(attn_fn) + + quantize_fn = PermuteToFrom(VectorQuantize(out_chan, fq_dict_size)) if num_layer in fq_layers else None + quantize_blocks.append(quantize_fn) + + self.blocks = nn.ModuleList(blocks) + self.attn_blocks = nn.ModuleList(attn_blocks) + self.quantize_blocks = nn.ModuleList(quantize_blocks) + + chan_last = filters[-1] + latent_dim = 2 * 2 * chan_last + + self.final_conv = nn.Conv2d(chan_last, chan_last, 3, padding=1) + self.flatten = Flatten() + self.to_logit = nn.Linear(latent_dim, 1) + + self._init_weights() + + def forward(self, x): + b, *_ = x.shape + + quantize_loss = torch.zeros(1).to(x) + + for (block, attn_block, q_block) in zip(self.blocks, self.attn_blocks, self.quantize_blocks): + x = block(x) + + if exists(attn_block): + x = attn_block(x) + + if exists(q_block): + x, _, loss = q_block(x) + quantize_loss += loss + + x = self.final_conv(x) + x = self.flatten(x) + x = self.to_logit(x) + if exists(q_block): + return x.squeeze(), quantize_loss + else: + return x.squeeze() + + def _init_weights(self): + for m in self.modules(): + if type(m) in {nn.Conv2d, nn.Linear}: + nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu') + + +class 
class StyleGan2PathLengthLoss(L.ConfigurableLoss):
    """Path-length penalty against a running mean of observed path lengths.

    ``opt['w_styles']`` and ``opt['gen']`` are the keys under which the style
    tensor and the generated images are looked up in ``state`` on each call.
    """

    def __init__(self, opt, env):
        super().__init__(opt, env)
        self.w_styles = opt['w_styles']
        self.gen = opt['gen']
        self.pl_mean = None
        from models.stylegan.stylegan2_lucidrains import EMA
        self.pl_length_ma = EMA(.99)

    def forward(self, net, state):
        """Return the squared deviation of path lengths from their running mean.

        Returns 0 on the first call (no running mean yet) and whenever the
        computed loss is NaN.
        """
        w_styles = state[self.w_styles]
        gen = state[self.gen]
        from models.stylegan.stylegan2_lucidrains import calc_pl_lengths
        pl_lengths = calc_pl_lengths(w_styles, gen)
        avg_pl_length = np.mean(pl_lengths.detach().cpu().numpy())

        from models.stylegan.stylegan2_lucidrains import is_empty
        pl_loss = None
        if not is_empty(self.pl_mean):
            # Compare against the mean from *before* this step's update.
            pl_loss = ((pl_lengths - self.pl_mean) ** 2).mean()
            if torch.isnan(pl_loss):
                print("Path length loss returned NaN!")
                pl_loss = None

        # Update the running mean on every step. The previous early return
        # skipped this whenever a valid loss was produced, so the mean only
        # moved on the first call or after a NaN. Skip NaN averages so a
        # single bad step cannot poison the mean permanently.
        if not np.isnan(avg_pl_length):
            self.pl_mean = self.pl_length_ma.update_average(self.pl_mean, avg_pl_length)

        return 0 if pl_loss is None else pl_loss
def get_vars(vars, source_name):
    """Look up one network's variables and make ``source_name`` relative to it.

    Args:
        vars: Mapping from network name to a dict whose ``'variables'`` entry
            is a sequence of tuples; element 0 of each tuple is the variable
            name and element 1 its value.
        source_name: Slash-separated path whose first component is the
            network name.

    Returns:
        A ``(name -> value dict, relative_name)`` pair, where
        ``relative_name`` is ``source_name`` without the leading
        ``"<network>/"`` prefix.
    """
    net_name = source_name.split('/')[0]
    result_vars = {t[0]: t[1] for t in vars[net_name]['variables']}

    # Strip only the leading prefix. str.replace would also remove any later
    # "<network>/" occurrence inside the path and corrupt the lookup key.
    prefix = net_name + "/"
    if source_name.startswith(prefix):
        source_name = source_name[len(prefix):]
    return result_vars, source_name
+ k] = torch.from_numpy(v) + + if flip: + dic_torch[target_name + f".conv{numeral}.weight"] = torch.flip( + dic_torch[target_name + f".conv{numeral}.weight"], [2, 3] + ) + + return dic_torch + + +def convert_conv(vars, source_name, target_name, bias=True, start=0): + vars, source_name = get_vars(vars, source_name) + weight = vars[source_name + "/weight"] + + dic = {"weight": weight.transpose((3, 2, 0, 1))} + + if bias: + dic["bias"] = vars[source_name + "/bias"] + + dic_torch = {} + + dic_torch[target_name + f".{start}.weight"] = torch.from_numpy(dic["weight"]) + + if bias: + dic_torch[target_name + f".{start + 1}.bias"] = torch.from_numpy(dic["bias"]) + + return dic_torch + + +def convert_torgb(vars, source_name, target_name): + vars, source_name = get_vars(vars, source_name) + weight = vars[source_name + "/weight"] + mod_weight = vars[source_name + "/mod_weight"] + mod_bias = vars[source_name + "/mod_bias"] + bias = vars[source_name + "/bias"] + + dic = { + "conv.weight": weight.transpose((3, 2, 0, 1)), + "to_style.weight": mod_weight.transpose((1, 0)), + "to_style.bias": mod_bias + 1, + # "bias": bias.reshape((1, 3, 1, 1)), TODO: where is this? + } + + dic_torch = {} + + for k, v in dic.items(): + dic_torch[target_name + "." + k] = torch.from_numpy(v) + + return dic_torch + + +def convert_dense(vars, source_name, target_name): + vars, source_name = get_vars(vars, source_name) + weight = vars[source_name + "/weight"] + bias = vars[source_name + "/bias"] + + dic = {"weight": weight.transpose((1, 0)), "bias": bias} + + dic_torch = {} + + for k, v in dic.items(): + dic_torch[target_name + "." 
+ k] = torch.from_numpy(v) + + return dic_torch + + +def update(state_dict, new, strict=True): + + for k, v in new.items(): + if strict: + if k not in state_dict: + raise KeyError(k + " is not found") + + if v.shape != state_dict[k].shape: + raise ValueError(f"Shape mismatch: {v.shape} vs {state_dict[k].shape}") + + state_dict[k] = v + + +def discriminator_fill_statedict(statedict, vars, size): + log_size = int(math.log(size, 2)) + + update(statedict, convert_conv(vars, f"{size}x{size}/FromRGB", "convs.0")) + + conv_i = 1 + + for i in range(log_size - 2, 0, -1): + reso = 4 * 2 ** i + update( + statedict, + convert_conv(vars, f"{reso}x{reso}/Conv0", f"convs.{conv_i}.conv1"), + ) + update( + statedict, + convert_conv( + vars, f"{reso}x{reso}/Conv1_down", f"convs.{conv_i}.conv2", start=1 + ), + ) + update( + statedict, + convert_conv( + vars, f"{reso}x{reso}/Skip", f"convs.{conv_i}.skip", start=1, bias=False + ), + ) + conv_i += 1 + + update(statedict, convert_conv(vars, f"4x4/Conv", "final_conv")) + update(statedict, convert_dense(vars, f"4x4/Dense0", "final_linear.0")) + update(statedict, convert_dense(vars, f"Output", "final_linear.1")) + + return statedict + + +def fill_statedict(state_dict, vars, size): + log_size = int(math.log(size, 2)) + + for i in range(8): + update(state_dict, convert_dense(vars, f"G_mapping/Dense{i}", f"vectorizer.net.{i}")) + + update( + state_dict, + { + "gen.initial_block": torch.from_numpy( + get_vars_direct(vars, "G_synthesis/4x4/Const/const") + ) + }, + ) + + for i in range(log_size - 1): + reso = 4 * 2 ** i + update( + state_dict, + convert_torgb(vars, f"G_synthesis/{reso}x{reso}/ToRGB", f"gen.blocks.{i}.to_rgb"), + ) + + update(state_dict, convert_modconv(vars, "G_synthesis/4x4/Conv", "gen.blocks.0", numeral=1)) + + for i in range(1, log_size - 1): + reso = 4 * 2 ** i + update( + state_dict, + convert_modconv( + vars, + f"G_synthesis/{reso}x{reso}/Conv0_up", + f"gen.blocks.{i}", + #flip=True, # TODO: why?? 
+ numeral=1 + ), + ) + update( + state_dict, + convert_modconv( + vars, f"G_synthesis/{reso}x{reso}/Conv1", f"gen.blocks.{i}", numeral=2 + ), + ) + + ''' + TODO: consider porting this, though I dont think it is necessary. + for i in range(0, (log_size - 2) * 2 + 1): + update( + state_dict, + { + f"noises.noise_{i}": torch.from_numpy( + get_vars_direct(vars, f"G_synthesis/noise{i}") + ) + }, + ) + ''' + + return state_dict + + +if __name__ == "__main__": + device = "cuda" + + parser = argparse.ArgumentParser( + description="Tensorflow to pytorch model checkpoint converter" + ) + parser.add_argument( + "--gen", action="store_true", help="convert the generator weights" + ) + parser.add_argument( + "--channel_multiplier", + type=int, + default=2, + help="channel multiplier factor. config-f = 2, else = 1", + ) + parser.add_argument("path", metavar="PATH", help="path to the tensorflow weights") + + args = parser.parse_args() + sys.path.append('scripts\\stylegan2') + + import dnnlib + from dnnlib.tflib.network import generator, gen_ema + + with open(args.path, "rb") as f: + pickle.load(f) + + # Weight names are ordered by size. The last name will be something like '1024x1024/'. We just need to grab that first number. 
+ size = int(generator['G_synthesis']['variables'][-1][0].split('x')[0]) + + g = StyleGan2GeneratorWithLatent(image_size=size, latent_dim=512, style_depth=8) + state_dict = g.state_dict() + state_dict = fill_statedict(state_dict, gen_ema, size) + + g.load_state_dict(state_dict, strict=True) + + latent_avg = torch.from_numpy(get_vars_direct(gen_ema, "G/dlatent_avg")) + + ckpt = {"g_ema": state_dict, "latent_avg": latent_avg} + + if args.gen: + g_train = Generator(size, 512, 8, channel_multiplier=args.channel_multiplier) + g_train_state = g_train.state_dict() + g_train_state = fill_statedict(g_train_state, generator, size) + ckpt["g"] = g_train_state + + name = os.path.splitext(os.path.basename(args.path))[0] + torch.save(ckpt, name + ".pt") + + batch_size = {256: 16, 512: 9, 1024: 4} + n_sample = batch_size.get(size, 25) + + g = g.to(device) + + z = np.random.RandomState(5).randn(n_sample, 512).astype("float32") + + with torch.no_grad(): + img_pt, _ = g(8) + + utils.save_image( + img_pt, name + ".png", nrow=n_sample, normalize=True, range=(-1, 1) + )