# Note this is an attempt to conform the lucidrains stylegan implementation to the official reference spec so that
# I could use pretrained weights from NVIDIA. It is not currently successful, but that may be due to the weight
# converter and not the code changes here. Use at your own risk.
import math
import multiprocessing
from contextlib import contextmanager, ExitStack
from functools import partial
from math import log2
from random import random

import torch
import torch.nn.functional as F
import trainer.losses as L
import numpy as np
from kornia.filters import filter2D
from linear_attention_transformer import ImageLinearAttention
from torch import nn
from torch.autograd import grad as torch_grad
from vector_quantize_pytorch import VectorQuantize

from utils.util import checkpoint

try:
    from apex import amp
    APEX_AVAILABLE = True
except:
    APEX_AVAILABLE = False

assert torch.cuda.is_available(), 'You need to have an Nvidia GPU with CUDA installed.'

num_cores = multiprocessing.cpu_count()

# constants
EPS = 1e-8
CALC_FID_NUM_IMAGES = 12800


# helper classes
def DiffAugment(x, types=[]):
    for p in types:
        for f in AUGMENT_FNS[p]:
            x = f(x)
    return x.contiguous()


def rand_brightness(x):
    x = x + (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5)
    return x


def rand_saturation(x):
    x_mean = x.mean(dim=1, keepdim=True)
    x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2) + x_mean
    return x


def rand_contrast(x):
    x_mean = x.mean(dim=[1, 2, 3], keepdim=True)
    x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5) + x_mean
    return x


def rand_translation(x, ratio=0.125):
    shift_x, shift_y = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
    translation_x = torch.randint(-shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device)
    translation_y = torch.randint(-shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device)
    grid_batch, grid_x, grid_y = torch.meshgrid(
        torch.arange(x.size(0), dtype=torch.long, device=x.device),
        torch.arange(x.size(2), dtype=torch.long, device=x.device),
        torch.arange(x.size(3), dtype=torch.long, device=x.device),
    )
    grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1)
    grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1)
    x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0])
    x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x, grid_y].permute(0, 3, 1, 2)
    return x


def rand_cutout(x, ratio=0.5):
    cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
    offset_x = torch.randint(0, x.size(2) + (1 - cutout_size[0] % 2), size=[x.size(0), 1, 1], device=x.device)
    offset_y = torch.randint(0, x.size(3) + (1 - cutout_size[1] % 2), size=[x.size(0), 1, 1], device=x.device)
    grid_batch, grid_x, grid_y = torch.meshgrid(
        torch.arange(x.size(0), dtype=torch.long, device=x.device),
        torch.arange(cutout_size[0], dtype=torch.long, device=x.device),
        torch.arange(cutout_size[1], dtype=torch.long, device=x.device),
    )
    grid_x = torch.clamp(grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1)
    grid_y = torch.clamp(grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1)
    mask = torch.ones(x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device)
    mask[grid_batch, grid_x, grid_y] = 0
    x = x * mask.unsqueeze(1)
    return x


AUGMENT_FNS = {
    'color': [rand_brightness, rand_saturation, rand_contrast],
    'translation': [rand_translation],
    'cutout': [rand_cutout],
}
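

# Illustrative usage of DiffAugment (the types listed here are arbitrary examples); in this file it
# is applied by StyleGan2Augmentor to whatever batch is fed to the wrapped discriminator, and every
# op above is differentiable so gradients flow back to the generator:
#   images = DiffAugment(images, types=['color', 'translation', 'cutout'])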


class NanException(Exception):
    pass


class EMA():
    def __init__(self, beta):
        super().__init__()
        self.beta = beta

    def update_average(self, old, new):
        if not exists(old):
            return new
        return old * self.beta + (1 - self.beta) * new


class Flatten(nn.Module):
    def forward(self, x):
        return x.reshape(x.shape[0], -1)


class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        return self.fn(x) + x


class Rezero(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn
        self.g = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        return self.fn(x) * self.g


class PermuteToFrom(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        x = x.permute(0, 2, 3, 1)
        # VectorQuantize may return (quantized, loss) or (quantized, indices, loss) depending on
        # version; the starred unpack keeps the first and last elements in either case.
        out, *_, loss = self.fn(x)
        out = out.permute(0, 3, 1, 2)
        return out, loss


class Blur(nn.Module):
    def __init__(self):
        super().__init__()
        f = torch.Tensor([1, 2, 1])
        self.register_buffer('f', f)

    def forward(self, x):
        f = self.f
        f = f[None, None, :] * f[None, :, None]
        return filter2D(x, f, normalized=True)


def leaky_relu(p=0.2):
    # Small factory used by the attention feedforward and the discriminator blocks below.
    return nn.LeakyReLU(p, inplace=True)


# one layer of self-attention and feedforward, for images
attn_and_ff = lambda chan: nn.Sequential(*[
    Residual(Rezero(ImageLinearAttention(chan, norm_queries=True))),
    Residual(Rezero(nn.Sequential(nn.Conv2d(chan, chan * 2, 1), leaky_relu(), nn.Conv2d(chan * 2, chan, 1))))
])


# helpers
def exists(val):
    return val is not None


@contextmanager
def null_context():
    yield


def combine_contexts(contexts):
    @contextmanager
    def multi_contexts():
        with ExitStack() as stack:
            yield [stack.enter_context(ctx()) for ctx in contexts]

    return multi_contexts


def default(value, d):
    return value if exists(value) else d


def cycle(iterable):
    while True:
        for i in iterable:
            yield i


def cast_list(el):
    return el if isinstance(el, list) else [el]


def is_empty(t):
    if isinstance(t, torch.Tensor):
        return t.nelement() == 0
    return not exists(t)


def raise_if_nan(t):
    if torch.isnan(t):
        raise NanException


def gradient_accumulate_contexts(gradient_accumulate_every, is_ddp, ddps):
    if is_ddp:
        num_no_syncs = gradient_accumulate_every - 1
        head = [combine_contexts(map(lambda ddp: ddp.no_sync, ddps))] * num_no_syncs
        tail = [null_context]
        contexts = head + tail
    else:
        contexts = [null_context] * gradient_accumulate_every

    for context in contexts:
        with context():
            yield


def loss_backwards(fp16, loss, optimizer, loss_id, **kwargs):
    if fp16:
        with amp.scale_loss(loss, optimizer, loss_id) as scaled_loss:
            scaled_loss.backward(**kwargs)
    else:
        loss.backward(**kwargs)


def gradient_penalty(images, output, weight=10):
    batch_size = images.shape[0]
    gradients = torch_grad(outputs=output, inputs=images,
                           grad_outputs=torch.ones(output.size(), device=images.device),
                           create_graph=True, retain_graph=True, only_inputs=True)[0]
    gradients = gradients.reshape(batch_size, -1)
    return weight * ((gradients.norm(2, dim=1) - 1) ** 2).mean()


def calc_pl_lengths(styles, images):
    device = images.device
    num_pixels = images.shape[2] * images.shape[3]
    pl_noise = torch.randn(images.shape, device=device) / math.sqrt(num_pixels)
    outputs = (images * pl_noise).sum()
    pl_grads = torch_grad(outputs=outputs, inputs=styles,
                          grad_outputs=torch.ones(outputs.shape, device=device),
                          create_graph=True, retain_graph=True, only_inputs=True)[0]
    return (pl_grads ** 2).sum(dim=2).mean(dim=1).sqrt()


def image_noise(n, im_size, device):
    return torch.FloatTensor(n, im_size, im_size, 1).uniform_(0., 1.).cuda(device)
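

# The BiasedLeakyReLU / biased_leaky_relu pair below appears intended as an unfused equivalent of
# the reference implementation's fused bias-act: add a learned per-channel bias, apply the leaky
# relu, then rescale by `scale` (sqrt(2) by default) to preserve signal magnitude.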
class BiasedLeakyReLU(nn.Module):
    def __init__(self, channel, bias=True, negative_slope=0.2, scale=2 ** 0.5):
        super().__init__()
        if bias:
            self.bias = nn.Parameter(torch.zeros(channel))
        else:
            self.bias = None
        self.negative_slope = negative_slope
        self.scale = scale

    def forward(self, input):
        return biased_leaky_relu(input, self.bias, self.negative_slope, self.scale)


def biased_leaky_relu(input, bias=None, negative_slope=0.2, scale=2 ** 0.5):
    if bias is not None:
        rest_dim = [1] * (input.ndim - bias.ndim - 1)
        return (
            F.leaky_relu(
                input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=negative_slope
            )
            * scale
        )
    else:
        return F.leaky_relu(input, negative_slope=negative_slope) * scale


def evaluate_in_chunks(max_batch_size, model, *args):
    split_args = list(zip(*list(map(lambda x: x.split(max_batch_size, dim=0), args))))
    chunked_outputs = [model(*i) for i in split_args]
    if len(chunked_outputs) == 1:
        return chunked_outputs[0]
    return torch.cat(chunked_outputs, dim=0)


def set_requires_grad(model, bool):
    for p in model.parameters():
        p.requires_grad = bool


def slerp(val, low, high):
    low_norm = low / torch.norm(low, dim=1, keepdim=True)
    high_norm = high / torch.norm(high, dim=1, keepdim=True)
    omega = torch.acos((low_norm * high_norm).sum(1))
    so = torch.sin(omega)
    res = (torch.sin((1.0 - val) * omega) / so).unsqueeze(1) * low + (torch.sin(val * omega) / so).unsqueeze(1) * high
    return res


# augmentations
def random_hflip(tensor, prob):
    if prob > random():
        return tensor
    return torch.flip(tensor, dims=(3,))


class StyleGan2Augmentor(nn.Module):
    def __init__(self, D, image_size, types, prob):
        super().__init__()
        self.D = D
        self.prob = prob
        self.types = types

    def forward(self, images, detach=False):
        if random() < self.prob:
            images = random_hflip(images, prob=0.5)
            images = DiffAugment(images, types=self.types)

        if detach:
            images = images.detach()

        # Save away for use elsewhere (e.g. unet loss)
        self.aug_images = images

        return self.D(images)


# stylegan2 classes
class EqualLinear(nn.Module):
    def __init__(self, in_dim, out_dim, lr_mul=1, bias=True, activation=False):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_dim, in_dim))
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_dim))

        self.lr_mul = lr_mul
        self.activation = activation

    def forward(self, input):
        if self.activation:
            # Activated path: fold the bias into the scaled leaky relu, so the mapping network
            # gets a nonlinearity after every layer.
            out = F.linear(input, self.weight * self.lr_mul)
            out = biased_leaky_relu(out, self.bias * self.lr_mul)
        else:
            out = F.linear(input, self.weight * self.lr_mul, bias=self.bias * self.lr_mul)
        return out


class StyleVectorizer(nn.Module):
    def __init__(self, emb, depth, lr_mul=0.01):
        super().__init__()

        layers = []
        for i in range(depth):
            layers.extend([EqualLinear(emb, emb, lr_mul, activation=True)])

        self.net = nn.Sequential(*layers)

    def forward(self, x):
        x = F.normalize(x, dim=1)
        return self.net(x)


class RGBBlock(nn.Module):
    def __init__(self, latent_dim, input_channel, upsample, rgba=False):
        super().__init__()
        self.input_channel = input_channel
        self.to_style = nn.Linear(latent_dim, input_channel)

        out_filters = 3 if not rgba else 4
        self.conv = Conv2DMod(input_channel, out_filters, 1, demod=False)

        self.upsample = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            Blur()
        ) if upsample else None

    def forward(self, x, prev_rgb, istyle):
        b, c, h, w = x.shape
        style = self.to_style(istyle)
        x = self.conv(x, style)

        if exists(prev_rgb):
            x = x + prev_rgb

        if exists(self.upsample):
            x = self.upsample(x)

        return x


class AdaptiveInstanceNorm(nn.Module):
    def __init__(self, in_channel, style_dim):
        super().__init__()
        from models.arch_util import ConvGnLelu
        self.style2scale = ConvGnLelu(style_dim, in_channel, kernel_size=1, norm=False, activation=False, bias=True)
        self.style2bias = ConvGnLelu(style_dim, in_channel, kernel_size=1, norm=False, activation=False, bias=True,
                                     weight_init_factor=0)
        self.norm = nn.InstanceNorm2d(in_channel)

    def forward(self, input, style):
        gamma = self.style2scale(style)
        beta = self.style2bias(style)
        out = self.norm(input)
        out = gamma * out + beta
        return out


class NoiseInjection(nn.Module):
    def __init__(self, channel):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(1, channel, 1, 1))

    def forward(self, image, noise):
        return image + self.weight * noise


class EqualLR:
    def __init__(self, name):
        self.name = name

    def compute_weight(self, module):
        weight = getattr(module, self.name + '_orig')
        fan_in = weight.data.size(1) * weight.data[0][0].numel()
        return weight * math.sqrt(2 / fan_in)

    @staticmethod
    def apply(module, name):
        fn = EqualLR(name)

        weight = getattr(module, name)
        del module._parameters[name]
        module.register_parameter(name + '_orig', nn.Parameter(weight.data))
        module.register_forward_pre_hook(fn)

        return fn

    def __call__(self, module, input):
        weight = self.compute_weight(module)
        setattr(module, self.name, weight)


def equal_lr(module, name='weight'):
    EqualLR.apply(module, name)
    return module


class EqualConv2d(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__()
        conv = nn.Conv2d(*args, **kwargs)
        conv.weight.data.normal_()
        conv.bias.data.zero_()
        self.conv = equal_lr(conv)

    def forward(self, input):
        return self.conv(input)


class Conv2DMod(nn.Module):
    def __init__(self, in_chan, out_chan, kernel, demod=True, stride=1, dilation=1, **kwargs):
        super().__init__()
        self.filters = out_chan
        self.demod = demod
        self.kernel = kernel
        self.stride = stride
        self.dilation = dilation
        self.weight = nn.Parameter(torch.randn((out_chan, in_chan, kernel, kernel)))
        nn.init.kaiming_normal_(self.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')

    def _get_same_padding(self, size, kernel, dilation, stride):
        return ((size - 1) * (stride - 1) + dilation * (kernel - 1)) // 2

    def forward(self, x, y):
        b, c, h, w = x.shape

        w1 = y[:, None, :, None, None]
        w2 = self.weight[None, :, :, :, :]
        weights = w2 * (w1 + 1)

        if self.demod:
            d = torch.rsqrt((weights ** 2).sum(dim=(2, 3, 4), keepdim=True) + EPS)
            weights = weights * d

        x = x.reshape(1, -1, h, w)

        _, _, *ws = weights.shape
        weights = weights.reshape(b * self.filters, *ws)

        padding = self._get_same_padding(h, self.kernel, self.dilation, self.stride)
        x = F.conv2d(x, weights, padding=padding, groups=b)

        x = x.reshape(-1, self.filters, h, w)
        return x
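

# Conv2DMod implements the StyleGAN2 modulated convolution: the per-sample style s scales the base
# kernel, w'_ijk = w_ijk * (s_j + 1), and with demod=True each output filter is then rescaled by
# 1 / sqrt(sum_jk w'_ijk^2 + EPS) (weight demodulation). The batch is folded into the channel
# dimension and the kernels are applied as a grouped convolution, so every sample gets its own
# modulated weights.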


class GeneratorBlock(nn.Module):
    def __init__(self, latent_dim, input_channels, filters, upsample=True, upsample_rgb=True, rgba=False,
                 initial_block=False):
        super().__init__()
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False) if upsample else None

        self.to_style1 = nn.Linear(latent_dim, input_channels)
        self.noise1_scale = nn.Parameter(torch.full((1,), fill_value=1e-5))
        self.conv1 = Conv2DMod(input_channels, filters, 3)
        self.activation1 = BiasedLeakyReLU(filters)

        self.initial_block = initial_block
        if not initial_block:
            self.to_style2 = nn.Linear(latent_dim, filters)
            self.noise2_scale = nn.Parameter(torch.full((1,), fill_value=1e-5))
            self.conv2 = Conv2DMod(filters, filters, 3)
            self.activation2 = BiasedLeakyReLU(filters)

        self.to_rgb = RGBBlock(latent_dim, filters, upsample_rgb, rgba)

    def forward(self, x, prev_rgb, istyle, inoise1=None, inoise2=None):
        if exists(self.upsample):
            x = self.upsample(x)

        if inoise1 is None:
            # Assume that both are None if one is None.
            b, c, h, w = x.shape
            inoise1 = torch.randn((b, 1, h, w), device=x.device)
            inoise2 = torch.randn((b, 1, h, w), device=x.device)

        noise1 = inoise1 * self.noise1_scale
        style1 = self.to_style1(istyle)
        x = self.conv1(x, style1)
        x = self.activation1(x + noise1)

        if not self.initial_block:
            noise2 = inoise2 * self.noise2_scale
            style2 = self.to_style2(istyle)
            x = self.conv2(x, style2)
            x = self.activation2(x + noise2)

        rgb = self.to_rgb(x, prev_rgb, istyle)
        return x, rgb


class Generator(nn.Module):
    def __init__(self, image_size, latent_dim, network_capacity=16, transparent=False, attn_layers=[], no_const=False,
                 fmap_max=512):
        super().__init__()
        self.image_size = image_size
        self.latent_dim = latent_dim
        self.num_layers = int(log2(image_size) - 1)

        filters = [network_capacity * (2 ** (i + 1)) for i in range(self.num_layers)][::-1]

        set_fmap_max = partial(min, fmap_max)
        filters = list(map(set_fmap_max, filters))
        init_channels = filters[0]
        filters = [init_channels, *filters]

        in_out_pairs = zip(filters[:-1], filters[1:])
        self.no_const = no_const

        if no_const:
            self.to_initial_block = nn.ConvTranspose2d(latent_dim, init_channels, 4, 1, 0, bias=False)
        else:
            self.initial_block = nn.Parameter(torch.randn((1, init_channels, 4, 4)))

        self.blocks = nn.ModuleList([])
        self.attns = nn.ModuleList([])

        for ind, (in_chan, out_chan) in enumerate(in_out_pairs):
            not_first = ind != 0
            not_last = ind != (self.num_layers - 1)
            num_layer = self.num_layers - ind

            attn_fn = attn_and_ff(in_chan) if num_layer in attn_layers else None
            self.attns.append(attn_fn)

            block_fn = GeneratorBlock
            block = block_fn(
                latent_dim,
                in_chan,
                out_chan,
                upsample=not_first,
                upsample_rgb=not_last,
                rgba=transparent,
                initial_block=(ind == 0)
            )
            self.blocks.append(block)

    def forward(self, styles, input_noises):
        batch_size = styles.shape[0]

        if self.no_const:
            avg_style = styles.mean(dim=1)[:, :, None, None]
            x = self.to_initial_block(avg_style)
        else:
            x = self.initial_block.expand(batch_size, -1, -1, -1)

        rgb = None
        styles = styles.transpose(0, 1)

        n = 0
        for style, block, attn in zip(styles, self.blocks, self.attns):
            if exists(attn):
                x = checkpoint(attn, x)
            x, rgb = checkpoint(block, x, rgb, style, input_noises[n], input_noises[n + 1])
            n = 1 if n == 0 else n + 2  # The first block only consumes 1 noise, the rest consume 2.

        return rgb


# Wrapper that combines style vectorizer with the actual generator.
class StyleGan2GeneratorWithLatent(nn.Module):
    def __init__(self, image_size, latent_dim=512, style_depth=8, lr_mlp=.1, network_capacity=16, transparent=False,
                 attn_layers=[], no_const=False, fmap_max=512):
        super().__init__()
        self.vectorizer = StyleVectorizer(latent_dim, style_depth, lr_mul=lr_mlp)
        self.gen = Generator(image_size, latent_dim, network_capacity, transparent, attn_layers, no_const, fmap_max)
        self.mixed_prob = .9
        self._init_weights()

    def noise(self, n, latent_dim, device):
        return torch.randn(n, latent_dim).cuda(device)

    def noise_list(self, n, layers, latent_dim, device):
        return [(self.noise(n, latent_dim, device), layers)]

    def mixed_list(self, n, layers, latent_dim, device):
        tt = int(torch.rand(()).numpy() * layers)
        return self.noise_list(n, tt, latent_dim, device) + self.noise_list(n, layers - tt, latent_dim, device)

    def latent_to_w(self, style_vectorizer, latent_descr):
        return [(style_vectorizer(z), num_layers) for z, num_layers in latent_descr]

    def styles_def_to_tensor(self, styles_def):
        return torch.cat([t[:, None, :].expand(-1, n, -1) for t, n in styles_def], dim=1)

    # If provided, 'noise' should be a list of tensors that is fed into each input block.
    # b=batch_size.
    def forward(self, b, noises=None):
        if noises is None:
            noises = [None] * (len(self.gen.blocks) * 2 - 1)

        device = next(self.parameters()).device
        full_random_latents = True
        if full_random_latents:
            style = self.noise(b * 2, self.gen.latent_dim, device)
            w = self.vectorizer(style)
            # Randomly distribute styles across layers
            w_styles = w[:, None, :].expand(-1, self.gen.num_layers, -1).clone()
            for j in range(b):
                cutoff = int(torch.rand(()).numpy() * self.gen.num_layers)
                if cutoff == self.gen.num_layers or random() > self.mixed_prob:
                    w_styles[j] = w_styles[j * 2]
                else:
                    w_styles[j, :cutoff] = w_styles[j * 2, :cutoff]
                    w_styles[j, cutoff:] = w_styles[j * 2 + 1, cutoff:]
            w_styles = w_styles[:b]
        else:
            get_latents_fn = self.mixed_list if random() < self.mixed_prob else self.noise_list
            style = get_latents_fn(b, self.gen.num_layers, self.gen.latent_dim, device=device)
            w_space = self.latent_to_w(self.vectorizer, style)
            w_styles = self.styles_def_to_tensor(w_space)

        return self.gen(w_styles, noises), w_styles

    def _init_weights(self):
        for m in self.modules():
            if type(m) in {nn.Conv2d, nn.Linear} and hasattr(m, 'weight'):
                nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')


class DiscriminatorBlock(nn.Module):
    def __init__(self, input_channels, filters, downsample=True):
        super().__init__()
        self.conv_res = nn.Conv2d(input_channels, filters, 1, stride=(2 if downsample else 1))

        self.net = nn.Sequential(
            nn.Conv2d(input_channels, filters, 3, padding=1),
            leaky_relu(),
            nn.Conv2d(filters, filters, 3, padding=1),
            leaky_relu()
        )

        self.downsample = nn.Sequential(
            Blur(),
            nn.Conv2d(filters, filters, 3, padding=1, stride=2)
        ) if downsample else None

    def forward(self, x):
        res = self.conv_res(x)
        x = self.net(x)
        if exists(self.downsample):
            x = self.downsample(x)
        x = (x + res) * (1 / math.sqrt(2))
        return x


class StyleGan2Discriminator(nn.Module):
    def __init__(self, image_size, network_capacity=16, fq_layers=[], fq_dict_size=256, attn_layers=[],
                 transparent=False, fmap_max=512, input_filters=3):
        super().__init__()
        num_layers = int(log2(image_size) - 1)

        filters = [input_filters] + [(64) * (2 ** i) for i in range(num_layers + 1)]

        set_fmap_max = partial(min, fmap_max)
        filters = list(map(set_fmap_max, filters))
        chan_in_out = list(zip(filters[:-1], filters[1:]))

        blocks = []
        attn_blocks = []
        quantize_blocks = []

        for ind, (in_chan, out_chan) in enumerate(chan_in_out):
            num_layer = ind + 1
            is_not_last = ind != (len(chan_in_out) - 1)

            block = DiscriminatorBlock(in_chan, out_chan, downsample=is_not_last)
            blocks.append(block)

            attn_fn = attn_and_ff(out_chan) if num_layer in attn_layers else None
            attn_blocks.append(attn_fn)

            quantize_fn = PermuteToFrom(VectorQuantize(out_chan, fq_dict_size)) if num_layer in fq_layers else None
            quantize_blocks.append(quantize_fn)

        self.blocks = nn.ModuleList(blocks)
        self.attn_blocks = nn.ModuleList(attn_blocks)
        self.quantize_blocks = nn.ModuleList(quantize_blocks)

        chan_last = filters[-1]
        latent_dim = 2 * 2 * chan_last

        self.final_conv = nn.Conv2d(chan_last, chan_last, 3, padding=1)
        self.flatten = Flatten()
        self.to_logit = nn.Linear(latent_dim, 1)

        self._init_weights()

    def forward(self, x):
        b, *_ = x.shape

        quantize_loss = torch.zeros(1).to(x)

        for (block, attn_block, q_block) in zip(self.blocks, self.attn_blocks, self.quantize_blocks):
            x = block(x)

            if exists(attn_block):
                x = attn_block(x)

            if exists(q_block):
                x, loss = q_block(x)
                quantize_loss += loss

        x = self.final_conv(x)
        x = self.flatten(x)
        x = self.to_logit(x)
        # Only return the accumulated quantization loss when feature quantization is configured.
        if any(map(exists, self.quantize_blocks)):
            return x.squeeze(), quantize_loss
        else:
            return x.squeeze()

    def _init_weights(self):
        for m in self.modules():
            if type(m) in {nn.Conv2d, nn.Linear}:
                nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')


class StyleGan2DivergenceLoss(L.ConfigurableLoss):
    def __init__(self, opt, env):
        super().__init__(opt, env)
        self.real = opt['real']
        self.fake = opt['fake']
        self.discriminator = opt['discriminator']
        self.for_gen = opt['gen_loss']
        self.gp_frequency = opt['gradient_penalty_frequency']
        self.noise = opt['noise'] if 'noise' in opt.keys() else 0

    def forward(self, net, state):
        real_input = state[self.real]
        fake_input = state[self.fake]
        if self.noise != 0:
            fake_input = fake_input + torch.rand_like(fake_input) * self.noise
            real_input = real_input + torch.rand_like(real_input) * self.noise

        D = self.env['discriminators'][self.discriminator]
        fake = D(fake_input)
        if self.for_gen:
            return fake.mean()
        else:
            real_input.requires_grad_()  # <-- Needed to compute gradients on the input.
            real = D(real_input)
            divergence_loss = (F.relu(1 + real) + F.relu(1 - fake)).mean()

            # Apply gradient penalty. TODO: migrate this elsewhere.
            if self.env['step'] % self.gp_frequency == 0:
                from models.stylegan.stylegan2_lucidrains import gradient_penalty
                gp = gradient_penalty(real_input, real)
                self.metrics.append(("gradient_penalty", gp.clone().detach()))
                divergence_loss = divergence_loss + gp
            real_input.requires_grad_(requires_grad=False)
            return divergence_loss


class StyleGan2PathLengthLoss(L.ConfigurableLoss):
    def __init__(self, opt, env):
        super().__init__(opt, env)
        self.w_styles = opt['w_styles']
        self.gen = opt['gen']
        self.pl_mean = None
        from models.stylegan.stylegan2_lucidrains import EMA
        self.pl_length_ma = EMA(.99)

    def forward(self, net, state):
        w_styles = state[self.w_styles]
        gen = state[self.gen]
        from models.stylegan.stylegan2_lucidrains import calc_pl_lengths
        pl_lengths = calc_pl_lengths(w_styles, gen)
        avg_pl_length = np.mean(pl_lengths.detach().cpu().numpy())

        from models.stylegan.stylegan2_lucidrains import is_empty
        pl_loss = 0
        if not is_empty(self.pl_mean):
            pl_loss = ((pl_lengths - self.pl_mean) ** 2).mean()
            if torch.isnan(pl_loss):
                print("Path length loss returned NaN!")
                pl_loss = 0
        # The loss above is computed against the previous mean; the running mean is then updated
        # every step so the regularization target tracks the current path lengths.
        self.pl_mean = self.pl_length_ma.update_average(self.pl_mean, avg_pl_length)
        return pl_loss
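

if __name__ == '__main__':
    # Minimal smoke-test sketch. The image size, batch size and default options below are
    # arbitrary illustrative values, not the configuration used for training.
    gen = StyleGan2GeneratorWithLatent(image_size=64, latent_dim=512, style_depth=8).cuda()
    disc = StyleGan2Discriminator(image_size=64).cuda()
    fake, w_styles = gen(4)  # fake: (4, 3, 64, 64), w_styles: (4, num_layers, 512)
    logits = disc(fake)      # (4,) raw discriminator scores
    print(fake.shape, w_styles.shape, logits.shape)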