import math
import multiprocessing
from contextlib import contextmanager, ExitStack

import torch
import torch.nn.functional as F
from kornia.filters import filter2D
from linear_attention_transformer import ImageLinearAttention
from torch import nn
from torch.autograd import grad as torch_grad

assert torch.cuda.is_available(), 'You need to have an Nvidia GPU with CUDA installed.'

num_cores = multiprocessing.cpu_count()

# constants

EPS = 1e-8


class NanException(Exception):
    """Raised by `raise_if_nan` when the checked tensor is NaN."""


class EMA:
    """Exponential-moving-average helper parameterised by a decay `beta`."""

    def __init__(self, beta):
        super().__init__()
        self.beta = beta

    def update_average(self, old, new):
        """Return `old * beta + (1 - beta) * new`, or `new` when `old` is None."""
        if not exists(old):
            return new
        return old * self.beta + (1 - self.beta) * new


class Flatten(nn.Module):
    """Reshape the input to `(x.shape[0], -1)`."""

    def forward(self, x):
        return x.reshape(x.shape[0], -1)


class Residual(nn.Module):
    """Wrap `fn` so that the module computes `fn(x) + x`."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        return self.fn(x) + x


class Rezero(nn.Module):
    """Scale the output of `fn` by a learned scalar gate initialised to zero."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn
        self.g = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        return self.fn(x) * self.g


class PermuteToFrom(nn.Module):
    """Run `fn` on the input permuted with (0, 2, 3, 1) and permute back.

    `fn` must return a pair `(out, loss)`; `out` is permuted with
    (0, 3, 1, 2) before being returned and `loss` is passed through.
    """

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        x = x.permute(0, 2, 3, 1)
        out, loss = self.fn(x)
        out = out.permute(0, 3, 1, 2)
        return out, loss


class Blur(nn.Module):
    """Filter the input with the normalised outer product of [1, 2, 1]."""

    def __init__(self):
        super().__init__()
        f = torch.Tensor([1, 2, 1])
        self.register_buffer('f', f)

    def forward(self, x):
        f = self.f
        # Outer product of the 1-D taps -> kernel of shape (1, 3, 3).
        f = f[None, None, :] * f[None, :, None]
        return filter2D(x, f, normalized=True)


# one layer of self-attention and feedforward, for images

def attn_and_ff(chan):
    """Build a residual/rezero attention layer followed by a 1x1-conv feedforward.

    Args:
        chan: channel count used for the attention and both 1x1 convolutions.

    Returns:
        An `nn.Sequential` of the two residual sub-layers.
    """
    return nn.Sequential(
        Residual(Rezero(ImageLinearAttention(chan, norm_queries=True))),
        Residual(Rezero(nn.Sequential(
            nn.Conv2d(chan, chan * 2, 1),
            leaky_relu(),
            nn.Conv2d(chan * 2, chan, 1),
        ))),
    )


# helpers

def exists(val):
    """Return True when `val` is not None."""
    return val is not None


@contextmanager
def null_context():
    """Context manager that does nothing."""
    yield


def combine_contexts(contexts):
    """Combine several context-manager factories into a single factory.

    Args:
        contexts: iterable of zero-argument callables, each returning a
            context manager.

    Returns:
        A zero-argument context-manager factory. Entering it enters every
        context in order and yields the list of their `__enter__` results.
    """
    # Materialise once: callers pass lazy iterables (e.g. `map`) and enter the
    # returned factory repeatedly; a bare iterator would be empty after the
    # first use and later entries would silently enter nothing.
    contexts = tuple(contexts)

    @contextmanager
    def multi_contexts():
        with ExitStack() as stack:
            yield [stack.enter_context(ctx()) for ctx in contexts]

    return multi_contexts


def default(value, d):
    """Return `value` unless it is None, in which case return `d`."""
    return value if exists(value) else d


def cycle(iterable):
    """Yield the items of `iterable` over and over, re-iterating it each pass."""
    while True:
        for i in iterable:
            yield i


def cast_list(el):
    """Return `el` unchanged if it is a list, otherwise wrap it in one."""
    return el if isinstance(el, list) else [el]


def is_empty(t):
    """Return True for a tensor with zero elements or for None."""
    if isinstance(t, torch.Tensor):
        return t.nelement() == 0
    return not exists(t)


def raise_if_nan(t):
    """Raise `NanException` if `torch.isnan(t)` is truthy."""
    if torch.isnan(t):
        raise NanException


def gradient_accumulate_contexts(gradient_accumulate_every, is_ddp, ddps):
    """Yield once per gradient-accumulation step, inside the proper context.

    With `is_ddp`, the first `gradient_accumulate_every - 1` steps run inside
    the `no_sync` context of every module in `ddps` and the last step runs
    inside `null_context`. Without it, every step uses `null_context`.
    """
    if is_ddp:
        num_no_syncs = gradient_accumulate_every - 1
        # A list (not a lazy map) so the combined factory can be reused for
        # each of the `num_no_syncs` steps.
        no_sync_all = combine_contexts([ddp.no_sync for ddp in ddps])
        head = [no_sync_all] * num_no_syncs
        tail = [null_context]
        contexts = head + tail
    else:
        contexts = [null_context] * gradient_accumulate_every

    for context in contexts:
        with context():
            yield


def loss_backwards(fp16, loss, optimizer, loss_id, **kwargs):
    """Call `backward(**kwargs)` on `loss`, through `amp.scale_loss` when `fp16`."""
    if fp16:
        # NOTE(review): `amp` is not imported in the visible part of this file;
        # presumably apex.amp - confirm, otherwise this branch raises NameError.
        with amp.scale_loss(loss, optimizer, loss_id) as scaled_loss:
            scaled_loss.backward(**kwargs)
    else:
        loss.backward(**kwargs)


def calc_pl_lengths(styles, images):
    """Compute per-sample path lengths of `images` with respect to `styles`.

    The gradient of `(images * noise).sum()` w.r.t. `styles` is taken with
    `create_graph=True`, then squared, summed over dim 2, averaged over dim 1
    and square-rooted (so `styles` must have at least three dimensions).
    """
    device = images.device
    num_pixels = images.shape[2] * images.shape[3]
    pl_noise = torch.randn(images.shape, device=device) / math.sqrt(num_pixels)
    outputs = (images * pl_noise).sum()

    pl_grads = torch_grad(outputs=outputs, inputs=styles,
                          grad_outputs=torch.ones(outputs.shape, device=device),
                          create_graph=True, retain_graph=True, only_inputs=True)[0]

    return (pl_grads ** 2).sum(dim=2).mean(dim=1).sqrt()


def image_noise(n, im_size, device):
    """Return uniform(0, 1) noise of shape (n, im_size, im_size, 1) on CUDA `device`."""
    return torch.FloatTensor(n, im_size, im_size, 1).uniform_(0., 1.).cuda(device)


def leaky_relu(p=0.2):
    """Return an in-place `nn.LeakyReLU` with negative slope `p`."""
    return nn.LeakyReLU(p, inplace=True)


def evaluate_in_chunks(max_batch_size, model, *args):
    """Run `model` over `args` split along dim 0 into chunks of `max_batch_size`.

    Returns the single output directly when there is only one chunk, otherwise
    the outputs concatenated along dim 0.
    """
    split_args = zip(*(arg.split(max_batch_size, dim=0) for arg in args))
    chunked_outputs = [model(*chunk) for chunk in split_args]
    if len(chunked_outputs) == 1:
        return chunked_outputs[0]
    return torch.cat(chunked_outputs, dim=0)


def set_requires_grad(model, bool):
    """Set `requires_grad` to `bool` on every parameter of `model`."""
    # The parameter name shadows the builtin; kept for caller compatibility.
    for p in model.parameters():
        p.requires_grad = bool


def slerp(val, low, high):
    """Spherical interpolation between `low` and `high` at fraction `val`.

    Norms and the dot product are taken along dim 1.
    """
    low_norm = low / torch.norm(low, dim=1, keepdim=True)
    high_norm = high / torch.norm(high, dim=1, keepdim=True)
    omega = torch.acos((low_norm * high_norm).sum(1))
    so = torch.sin(omega)
    res = (torch.sin((1.0 - val) * omega) / so).unsqueeze(1) * low \
        + (torch.sin(val * omega) / so).unsqueeze(1) * high
    return res


# stylegan2 classes

class EqualLinear(nn.Module):
    """Linear layer whose weight and bias are multiplied by `lr_mul` at call time.

    Args:
        in_dim: input feature count.
        out_dim: output feature count.
        lr_mul: multiplier applied to weight and bias in `forward`.
        bias: when False the layer has no bias term.
    """

    def __init__(self, in_dim, out_dim, lr_mul=1, bias=True):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_dim, in_dim))
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_dim))
        else:
            # Register None so `self.bias` always exists; `forward` previously
            # raised AttributeError for bias=False.
            self.register_parameter('bias', None)

        self.lr_mul = lr_mul

    def forward(self, input):
        bias = self.bias * self.lr_mul if exists(self.bias) else None
        return F.linear(input, self.weight * self.lr_mul, bias=bias)


class StyleVectorizer(nn.Module):
    """Stack of `depth` (EqualLinear, leaky ReLU) pairs applied to a normalised input."""

    def __init__(self, emb, depth, lr_mul=0.1):
        super().__init__()

        layers = []
        for _ in range(depth):
            layers.extend([EqualLinear(emb, emb, lr_mul), leaky_relu()])

        self.net = nn.Sequential(*layers)

    def forward(self, x):
        x = F.normalize(x, dim=1)
        return self.net(x)


class RGBBlock(nn.Module):
    """Project features to 3 (or 4 with `rgba`) channels via a modulated 1x1 conv.

    Optionally adds a previous RGB tensor and upsamples the result 2x
    (bilinear followed by `Blur`).
    """

    def __init__(self, latent_dim, input_channel, upsample, rgba=False):
        super().__init__()
        self.input_channel = input_channel
        self.to_style = nn.Linear(latent_dim, input_channel)

        out_filters = 3 if not rgba else 4
        self.conv = Conv2DMod(input_channel, out_filters, 1, demod=False)

        self.upsample = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            Blur()
        ) if upsample else None

    def forward(self, x, prev_rgb, istyle):
        style = self.to_style(istyle)
        x = self.conv(x, style)

        if exists(prev_rgb):
            x = x + prev_rgb

        if exists(self.upsample):
            x = self.upsample(x)

        return x


class AdaptiveInstanceNorm(nn.Module):
    """Instance-normalise the input, then apply a style-derived scale and bias."""

    def __init__(self, in_channel, style_dim):
        super().__init__()
        # Project-local import performed at construction time.
        from models.archs.arch_util import ConvGnLelu
        self.style2scale = ConvGnLelu(style_dim, in_channel, kernel_size=1, norm=False,
                                      activation=False, bias=True)
        self.style2bias = ConvGnLelu(style_dim, in_channel, kernel_size=1, norm=False,
                                     activation=False, bias=True, weight_init_factor=0)
        self.norm = nn.InstanceNorm2d(in_channel)

    def forward(self, input, style):
        gamma = self.style2scale(style)
        beta = self.style2bias(style)
        out = self.norm(input)
        out = gamma * out + beta
        return out


class NoiseInjection(nn.Module):
    """Add `noise` scaled by a learned per-channel weight (initialised to zero)."""

    def __init__(self, channel):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(1, channel, 1, 1))

    def forward(self, image, noise):
        return image + self.weight * noise


class EqualLR:
    """Forward-pre-hook that rescales `<name>_orig` by sqrt(2 / fan_in) on each call."""

    def __init__(self, name):
        self.name = name

    def compute_weight(self, module):
        """Return the stored `<name>_orig` parameter scaled by sqrt(2 / fan_in)."""
        weight = getattr(module, self.name + '_orig')
        fan_in = weight.data.size(1) * weight.data[0][0].numel()

        return weight * math.sqrt(2 / fan_in)

    @staticmethod
    def apply(module, name):
        """Move parameter `name` to `<name>_orig` and install the rescaling hook."""
        fn = EqualLR(name)

        weight = getattr(module, name)
        del module._parameters[name]
        module.register_parameter(name + '_orig', nn.Parameter(weight.data))
        module.register_forward_pre_hook(fn)

        return fn

    def __call__(self, module, input):
        weight = self.compute_weight(module)
        setattr(module, self.name, weight)


def equal_lr(module, name='weight'):
    """Apply `EqualLR` to parameter `name` of `module` and return the module."""
    EqualLR.apply(module, name)

    return module


class EqualConv2d(nn.Module):
    """`nn.Conv2d` with normal-initialised weight, zero bias and `equal_lr` applied."""

    def __init__(self, *args, **kwargs):
        super().__init__()

        conv = nn.Conv2d(*args, **kwargs)
        conv.weight.data.normal_()
        # `bias=False` may be forwarded to nn.Conv2d, leaving `conv.bias` None.
        if conv.bias is not None:
            conv.bias.data.zero_()
        self.conv = equal_lr(conv)

    def forward(self, input):
        return self.conv(input)


class Conv2DMod(nn.Module):
    """Convolution whose weights are modulated per sample by a style vector.

    Args:
        in_chan: input channel count.
        out_chan: output channel count.
        kernel: square kernel size.
        demod: when True, modulated weights are rescaled by the reciprocal
            square root of their squared sum (plus EPS).
        stride: used only when computing the padding.
        dilation: used only when computing the padding.
        **kwargs: accepted and ignored.
    """

    def __init__(self, in_chan, out_chan, kernel, demod=True, stride=1, dilation=1, **kwargs):
        super().__init__()
        self.filters = out_chan
        self.demod = demod
        self.kernel = kernel
        self.stride = stride
        self.dilation = dilation
        self.weight = nn.Parameter(torch.randn((out_chan, in_chan, kernel, kernel)))
        nn.init.kaiming_normal_(self.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')

    def _get_same_padding(self, size, kernel, dilation, stride):
        return ((size - 1) * (stride - 1) + dilation * (kernel - 1)) // 2

    def forward(self, x, y):
        """Convolve `x` (b, c, h, w) with weights modulated by `y` (indexed as (b, c))."""
        b, _, h, w = x.shape

        # Broadcast style over (out, kh, kw) and weight over the batch.
        w1 = y[:, None, :, None, None]
        w2 = self.weight[None, :, :, :, :]
        weights = w2 * (w1 + 1)

        if self.demod:
            d = torch.rsqrt((weights ** 2).sum(dim=(2, 3, 4), keepdim=True) + EPS)
            weights = weights * d

        # Fold the batch into the channel axis and use one group per sample so
        # each sample is convolved with its own weights.
        x = x.reshape(1, -1, h, w)

        _, _, *ws = weights.shape
        weights = weights.reshape(b * self.filters, *ws)

        # NOTE: stride and dilation affect only the padding; they are not
        # passed to F.conv2d.
        padding = self._get_same_padding(h, self.kernel, self.dilation, self.stride)
        x = F.conv2d(x, weights, padding=padding, groups=b)

        x = x.reshape(-1, self.filters, h, w)
        return x


class GeneratorBlock(nn.Module):
    """Generator stage: optional 2x upsample, two modulated convs with noise, RGB head.

    When `structure_input` is True, a 3-channel structure tensor is projected
    to `input_channels` by a 3x3 conv and concatenated to the features, so the
    first modulated conv sees `2 * input_channels` channels.
    """

    def __init__(self, latent_dim, input_channels, filters, upsample=True, upsample_rgb=True,
                 rgba=False, structure_input=False):
        super().__init__()
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear',
                                    align_corners=False) if upsample else None

        self.structure_input = structure_input
        if self.structure_input:
            self.structure_conv = nn.Conv2d(3, input_channels, 3, padding=1)
            input_channels = input_channels * 2

        self.to_style1 = nn.Linear(latent_dim, input_channels)
        self.to_noise1 = nn.Linear(1, filters)
        self.conv1 = Conv2DMod(input_channels, filters, 3)

        self.to_style2 = nn.Linear(latent_dim, filters)
        self.to_noise2 = nn.Linear(1, filters)
        self.conv2 = Conv2DMod(filters, filters, 3)

        self.activation = leaky_relu()
        self.to_rgb = RGBBlock(latent_dim, filters, upsample_rgb, rgba)

    def forward(self, x, prev_rgb, istyle, inoise, structure_input=None):
        """Return `(features, rgb)` for this stage.

        `structure_input` is only read when the block was built with
        `structure_input=True`.
        """
        if exists(self.upsample):
            x = self.upsample(x)

        if self.structure_input:
            s = self.structure_conv(structure_input)
            x = torch.cat([x, s], dim=1)

        # Crop the noise to the current spatial size, project its last axis
        # (size 1) to `filters`, then permute with (0, 3, 2, 1).
        inoise = inoise[:, :x.shape[2], :x.shape[3], :]
        noise1 = self.to_noise1(inoise).permute((0, 3, 2, 1))
        noise2 = self.to_noise2(inoise).permute((0, 3, 2, 1))

        style1 = self.to_style1(istyle)
        x = self.conv1(x, style1)
        x = self.activation(x + noise1)

        style2 = self.to_style2(istyle)
        x = self.conv2(x, style2)
        x = self.activation(x + noise2)

        rgb = self.to_rgb(x, prev_rgb, istyle)
        return x, rgb