412 lines
12 KiB
Python
412 lines
12 KiB
Python
import math
|
|
import multiprocessing
|
|
from contextlib import contextmanager, ExitStack
|
|
|
|
import torch
|
|
import torch.nn.functional as F
|
|
from kornia.filters import filter2D
|
|
from linear_attention_transformer import ImageLinearAttention
|
|
from torch import nn, Tensor
|
|
from torch.autograd import grad as torch_grad
|
|
from torch.nn import Parameter, init
|
|
from torch.nn.modules.conv import _ConvNd
|
|
|
|
from models.styled_sr.transfer_primitives import TransferLinear
|
|
|
|
# Import-time guard: loading this module fails unless torch reports a CUDA device.
# NOTE(review): `assert` is stripped under `python -O`, so the guard disappears there.
assert torch.cuda.is_available(), 'You need to have an Nvidia GPU with CUDA installed.'

# CPU count as reported by multiprocessing; not referenced again in the visible file.
num_cores = multiprocessing.cpu_count()

# constants

# Added under the rsqrt in Conv2DMod's demodulation to avoid dividing by zero.
EPS = 1e-8
|
|
|
|
|
|
class NanException(Exception):
    """Raised by ``raise_if_nan`` when a NaN is detected."""
|
|
|
|
|
|
class EMA():
    """Exponential moving average with a fixed decay factor ``beta``."""

    def __init__(self, beta):
        super().__init__()
        self.beta = beta

    def update_average(self, old, new):
        """Return ``old`` blended with ``new``.

        When ``old`` is None there is no running value yet, so ``new`` is
        returned unchanged. Otherwise the result is
        ``old * beta + (1 - beta) * new``.
        """
        if old is None:
            return new
        decay = self.beta
        return old * decay + (1 - decay) * new
|
|
|
|
|
|
class Flatten(nn.Module):
    """Collapse every dimension after the first into a single one."""

    def forward(self, x):
        # Keep dim 0 as-is and let reshape infer the size of the merged remainder.
        leading = x.shape[0]
        return x.reshape(leading, -1)
|
|
|
|
|
|
class Residual(nn.Module):
    """Wrap ``fn`` so that its output is added to its own input."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        # The wrapped branch runs first; the untouched input is added afterwards.
        branch = self.fn(x)
        return branch + x
|
|
|
|
|
|
class Rezero(nn.Module):
    """Scale the output of ``fn`` by a single learnable gate that starts at zero."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn
        # One-element parameter initialised to 0, so the wrapped branch
        # contributes nothing until the gate is trained away from zero.
        self.g = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        gated = self.fn(x)
        return gated * self.g
|
|
|
|
|
|
class PermuteToFrom(nn.Module):
    """Run ``fn`` with dim 1 moved to the last position, then move it back.

    ``fn`` must return a pair ``(out, loss)``; only ``out`` is permuted back,
    ``loss`` is passed through untouched.
    """

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        # (d0, d1, d2, d3) -> (d0, d2, d3, d1) for the wrapped callable.
        moved = x.permute(0, 2, 3, 1)
        result, loss = self.fn(moved)
        # Inverse permutation restores the original dimension order.
        restored = result.permute(0, 3, 1, 2)
        return restored, loss
|
|
|
|
|
|
class Blur(nn.Module):
    """Filter the input with a 3x3 kernel built from the taps ``[1, 2, 1]``."""

    def __init__(self):
        super().__init__()
        # Registered as a buffer so it follows the module across devices
        # without being a trainable parameter.
        self.register_buffer('f', torch.Tensor([1, 2, 1]))

    def forward(self, x):
        taps = self.f
        # Outer product of the taps with themselves, with a leading dim of 1:
        # the resulting kernel has shape (1, 3, 3).
        kernel = taps[None, None, :] * taps[None, :, None]
        return filter2D(x, kernel, normalized=True)
|
|
|
|
|
|
# one layer of self-attention and feedforward, for images


def attn_and_ff(chan):
    """Build one attention + feed-forward layer for ``chan``-channel feature maps.

    Both sub-layers are wrapped as ``Residual(Rezero(...))``. The feed-forward
    part is two 1x1 convolutions (``chan -> 2*chan -> chan``) with a leaky ReLU
    in between.
    """
    attention = Residual(Rezero(ImageLinearAttention(chan, norm_queries=True)))
    feed_forward = Residual(Rezero(nn.Sequential(
        nn.Conv2d(chan, chan * 2, 1),
        leaky_relu(),
        nn.Conv2d(chan * 2, chan, 1),
    )))
    return nn.Sequential(attention, feed_forward)
|
|
|
|
|
|
# helpers
|
|
|
|
def exists(val):
    """Return True for anything except ``None`` (falsy values such as 0 count as existing)."""
    if val is None:
        return False
    return True
|
|
|
|
|
|
@contextmanager
def null_context():
    """Context manager that does nothing on entry or exit and binds ``None``."""
    yield None
|
|
|
|
|
|
def combine_contexts(contexts):
    """Merge several context-manager factories into one reusable factory.

    Args:
        contexts: iterable of zero-argument callables, each returning a
            context manager.

    Returns:
        A zero-argument context-manager factory. Entering it enters every
        wrapped context in order (via an ``ExitStack``) and binds the list of
        their ``__enter__`` results.
    """
    # Materialise once up front: callers may pass a one-shot iterator such as
    # ``map(...)``, and the returned factory is meant to be entered repeatedly.
    # Iterating the raw iterator on each entry would leave every entry after
    # the first with nothing to enter.
    factories = list(contexts)

    @contextmanager
    def multi_contexts():
        with ExitStack() as stack:
            yield [stack.enter_context(make_ctx()) for make_ctx in factories]

    return multi_contexts
|
|
|
|
|
|
def default(value, d):
    """Return ``value`` unless it is ``None``, in which case return the fallback ``d``."""
    if value is None:
        return d
    return value
|
|
|
|
|
|
def cycle(iterable):
    """Yield the items of ``iterable`` over and over, re-iterating it on each pass.

    Nothing is cached: every pass calls ``iter`` on ``iterable`` again.
    """
    while True:
        yield from iterable
|
|
|
|
|
|
def cast_list(el):
    """Return ``el`` itself if it is a list, otherwise wrap it in a one-element list."""
    if isinstance(el, list):
        return el
    return [el]
|
|
|
|
|
|
def is_empty(t):
    """Report whether ``t`` holds nothing.

    A tensor is empty when it has zero elements; any other value is empty
    only when it is ``None``.
    """
    if not isinstance(t, torch.Tensor):
        return t is None
    return t.nelement() == 0
|
|
|
|
|
|
def raise_if_nan(t):
    """Raise ``NanException`` if any element of tensor ``t`` is NaN.

    Works for scalar and multi-element tensors alike: the element-wise NaN
    mask is reduced with ``any()`` before it is used as a condition, because
    a multi-element boolean tensor has no single truth value.

    Raises:
        NanException: when at least one element of ``t`` is NaN.
    """
    if torch.isnan(t).any():
        raise NanException
|
|
|
|
|
|
def gradient_accumulate_contexts(gradient_accumulate_every, is_ddp, ddps):
    """Generator that yields once per gradient-accumulation step, inside the right context.

    Args:
        gradient_accumulate_every: number of accumulation steps.
        is_ddp: when true, all steps except the last run inside ``no_sync`` of
            every wrapper in ``ddps``; the last step runs in a null context.
        ddps: iterable of objects exposing a ``no_sync`` context-manager
            factory. Only read when ``is_ddp`` is true.

    Yields:
        ``None`` once per step, while the step's context is active.
    """
    if is_ddp:
        # Build a real list rather than a one-shot ``map``: the combined
        # context is entered once per no-sync step, and an exhausted iterator
        # would make every step after the first enter nothing at all.
        no_sync_all = combine_contexts([ddp.no_sync for ddp in ddps])
        num_no_syncs = gradient_accumulate_every - 1
        contexts = [no_sync_all] * num_no_syncs + [null_context]
    else:
        contexts = [null_context] * gradient_accumulate_every

    for context in contexts:
        with context():
            yield
|
|
|
|
|
|
def loss_backwards(fp16, loss, optimizer, loss_id, **kwargs):
    """Back-propagate ``loss``, optionally through ``amp`` loss scaling.

    Args:
        fp16: when falsy, ``loss.backward(**kwargs)`` is called directly.
        loss: the loss to differentiate.
        optimizer, loss_id: forwarded to ``amp.scale_loss``; unused when
            ``fp16`` is falsy.
        **kwargs: forwarded to ``backward``.
    """
    if not fp16:
        loss.backward(**kwargs)
        return

    # NOTE(review): `amp` is not imported anywhere in the visible file, so this
    # branch raises NameError as written. The call signature looks like NVIDIA
    # apex's `amp.scale_loss` -- confirm and add the import.
    with amp.scale_loss(loss, optimizer, loss_id) as scaled_loss:
        scaled_loss.backward(**kwargs)
|
|
|
|
def calc_pl_lengths(styles, images):
    """Compute per-sample gradient lengths of noise-projected ``images`` w.r.t. ``styles``.

    ``images`` is multiplied by Gaussian noise scaled by ``1/sqrt(H*W)`` (dims
    2 and 3 of ``images``) and summed to a scalar. The gradient of that scalar
    with respect to ``styles`` is squared, summed over dim 2, averaged over
    dim 1 and square-rooted. The graph is kept (``create_graph=True``) so the
    result can itself be differentiated.

    Args:
        styles: tensor that ``images`` was computed from; indexed on dims 1
            and 2 here, so it must have at least three dimensions.
        images: tensor with at least four dimensions.
    """
    device = images.device
    height, width = images.shape[2], images.shape[3]
    noise_scale = math.sqrt(height * width)
    pl_noise = torch.randn(images.shape, device=device) / noise_scale
    outputs = (images * pl_noise).sum()

    grads = torch_grad(
        outputs=outputs,
        inputs=styles,
        grad_outputs=torch.ones(outputs.shape, device=device),
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )
    pl_grads = grads[0]

    squared = pl_grads ** 2
    return squared.sum(dim=2).mean(dim=1).sqrt()
|
|
|
|
|
|
def image_noise(n, im_size, device):
    """Return uniform noise in [0, 1) of shape ``(n, im_size, im_size, 1)`` on a CUDA device.

    The tensor is created and filled on the CPU, then moved with
    ``.cuda(device)``.
    """
    noise = torch.FloatTensor(n, im_size, im_size, 1)
    noise = noise.uniform_(0., 1.)
    return noise.cuda(device)
|
|
|
|
|
|
def leaky_relu(p=0.2):
    """Create an in-place ``nn.LeakyReLU`` with negative slope ``p``."""
    activation = nn.LeakyReLU(p, inplace=True)
    return activation
|
|
|
|
|
|
def evaluate_in_chunks(max_batch_size, model, *args):
    """Apply ``model`` to ``args`` in slices of at most ``max_batch_size`` along dim 0.

    Every positional argument is split along dim 0 and ``model`` is called
    once per group of matching slices. A single chunk's output is returned
    as-is; otherwise the outputs are concatenated along dim 0.
    """
    per_arg_chunks = [arg.split(max_batch_size, dim=0) for arg in args]
    outputs = [model(*group) for group in zip(*per_arg_chunks)]
    if len(outputs) == 1:
        return outputs[0]
    return torch.cat(outputs, dim=0)
|
|
|
|
|
|
def set_requires_grad(model, bool):
    """Set ``requires_grad`` to ``bool`` on every parameter of ``model``.

    NOTE(review): the second parameter shadows the builtin ``bool``; it is kept
    as-is because renaming it would break keyword callers.
    """
    flag = bool
    for param in model.parameters():
        param.requires_grad = flag
|
|
|
|
|
|
def slerp(val, low, high):
    """Spherical interpolation between the rows of ``low`` and ``high``.

    Args:
        val: interpolation factor; 0 gives ``low``, 1 gives ``high``.
        low, high: 2-D tensors of matching shape. The angle is measured
            between corresponding rows after normalising along dim 1.

    Returns:
        Tensor shaped like ``low``: ``sin((1-val)*w)/sin(w) * low +
        sin(val*w)/sin(w) * high`` with ``w`` the angle between the rows.
        Rows whose angle has a vanishing sine (parallel or anti-parallel
        inputs) use the linear weights ``1 - val`` and ``val`` instead, since
        the spherical weights would be 0/0 there.
    """
    low_norm = low / torch.norm(low, dim=1, keepdim=True)
    high_norm = high / torch.norm(high, dim=1, keepdim=True)

    # Rounding can push the cosine marginally outside [-1, 1], where acos is
    # NaN; clamping leaves valid values untouched.
    cos_omega = (low_norm * high_norm).sum(1).clamp(-1.0, 1.0)
    omega = torch.acos(cos_omega)
    so = torch.sin(omega)

    # Guard the division: substitute 1 where sin(omega) vanishes, then
    # overwrite those rows' weights with the linear-interpolation ones.
    degenerate = so.abs() < 1e-6
    safe_so = torch.where(degenerate, torch.ones_like(so), so)
    w_low = torch.sin((1.0 - val) * omega) / safe_so
    w_high = torch.sin(val * omega) / safe_so
    w_low = torch.where(degenerate, torch.zeros_like(so) + (1.0 - val), w_low)
    w_high = torch.where(degenerate, torch.zeros_like(so) + val, w_high)

    return w_low.unsqueeze(1) * low + w_high.unsqueeze(1) * high
|
|
|
|
|
|
class EqualLinear(nn.Module):
    """Linear layer whose weight and bias are multiplied by ``lr_mul`` at call time.

    Args:
        in_dim: size of the last input dimension.
        out_dim: size of the last output dimension.
        lr_mul: multiplier applied to weight and bias in ``forward``.
        bias: when False no bias is created; ``self.bias`` is then ``None``.
        transfer_mode: when True, adds per-weight ``transfer_scale`` (ones) and
            ``transfer_shift`` (zeros) parameters, both tagged with
            ``FOR_TRANSFER_LEARNING = True``, and applies them to the weight.
    """

    def __init__(self, in_dim, out_dim, lr_mul=1, bias=True, transfer_mode=False):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_dim, in_dim))
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_dim))
        else:
            # Register the name explicitly so ``forward`` can test for None
            # instead of hitting a missing attribute.
            self.register_parameter('bias', None)

        self.lr_mul = lr_mul

        self.transfer_mode = transfer_mode
        if transfer_mode:
            # Same shape as ``weight`` so scale and shift apply element-wise.
            self.transfer_scale = nn.Parameter(torch.ones(out_dim, in_dim))
            self.transfer_scale.FOR_TRANSFER_LEARNING = True
            self.transfer_shift = nn.Parameter(torch.zeros(out_dim, in_dim))
            self.transfer_shift.FOR_TRANSFER_LEARNING = True

    def forward(self, input):
        if self.transfer_mode:
            weight = self.weight * self.transfer_scale + self.transfer_shift
        else:
            weight = self.weight
        bias = None if self.bias is None else self.bias * self.lr_mul
        return F.linear(input, weight * self.lr_mul, bias=bias)
|
|
|
|
|
|
class StyleVectorizer(nn.Module):
    """Stack of ``depth`` (EqualLinear, leaky ReLU) pairs of constant width ``emb``.

    Args:
        emb: input and output width of every linear layer.
        depth: number of (linear, activation) pairs.
        lr_mul: passed to every ``EqualLinear``.
        transfer_mode: passed to every ``EqualLinear``.
    """

    def __init__(self, emb, depth, lr_mul=0.1, transfer_mode=False):
        super().__init__()

        stack = []
        for _ in range(depth):
            stack.append(EqualLinear(emb, emb, lr_mul, transfer_mode=transfer_mode))
            stack.append(leaky_relu())

        self.net = nn.Sequential(*stack)

    def forward(self, x):
        # Normalise along dim 1 before the layer stack.
        normalised = F.normalize(x, dim=1)
        return self.net(normalised)
|
|
|
|
|
|
class RGBBlock(nn.Module):
    """Project features to a 3- or 4-channel image and optionally upsample it.

    Args:
        latent_dim: width of the style vector fed to ``to_style``.
        input_channel: channel count of the incoming feature map.
        upsample: when truthy, the output is bilinearly upsampled by 2 and blurred.
        rgba: when truthy the projection has 4 output channels instead of 3.
        transfer_mode: forwarded to the modulated 1x1 convolution.
    """

    def __init__(self, latent_dim, input_channel, upsample, rgba=False, transfer_mode=False):
        super().__init__()
        self.input_channel = input_channel
        self.to_style = nn.Linear(latent_dim, input_channel)

        out_filters = 4 if rgba else 3
        self.conv = Conv2DMod(input_channel, out_filters, 1, demod=False, transfer_mode=transfer_mode)

        if upsample:
            self.upsample = nn.Sequential(
                nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
                Blur(),
            )
        else:
            self.upsample = None

    def forward(self, x, prev_rgb, istyle):
        """Return the image for this stage; ``prev_rgb`` (if not None) is added before upsampling."""
        # Unpacking is kept only because it rejects inputs that are not 4-D;
        # the individual sizes are not used below.
        b, c, h, w = x.shape
        style = self.to_style(istyle)
        rgb = self.conv(x, style)

        if exists(prev_rgb):
            rgb = rgb + prev_rgb

        if exists(self.upsample):
            rgb = self.upsample(rgb)

        return rgb
|
|
|
|
|
|
class AdaptiveInstanceNorm(nn.Module):
    """Instance-normalise the input, then scale and shift it with maps derived from ``style``.

    Args:
        in_channel: channel count of the normalised input.
        style_dim: channel count of the style input.
    """

    def __init__(self, in_channel, style_dim):
        super().__init__()
        # Imported here rather than at module level, as in the original.
        from models.archs.arch_util import ConvGnLelu
        self.style2scale = ConvGnLelu(style_dim, in_channel, kernel_size=1, norm=False, activation=False, bias=True)
        # The shift branch is created with weight_init_factor=0.
        self.style2bias = ConvGnLelu(style_dim, in_channel, kernel_size=1, norm=False, activation=False, bias=True, weight_init_factor=0)
        self.norm = nn.InstanceNorm2d(in_channel)

    def forward(self, input, style):
        scale = self.style2scale(style)
        shift = self.style2bias(style)
        normalised = self.norm(input)
        return scale * normalised + shift
|
|
|
|
|
|
class NoiseInjection(nn.Module):
    """Add ``noise`` to ``image``, scaled by a learnable per-channel weight.

    The weight has shape ``(1, channel, 1, 1)`` and starts at zero, so the
    module initially returns ``image`` unchanged.
    """

    def __init__(self, channel):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(1, channel, 1, 1))

    def forward(self, image, noise):
        scaled_noise = self.weight * noise
        return image + scaled_noise
|
|
|
|
|
|
class EqualLR:
    """Forward pre-hook that recomputes a module's weight from ``<name>_orig`` on every call.

    The recomputed weight is the stored parameter times ``sqrt(2 / fan_in)``,
    where ``fan_in`` is the size of dim 1 times the element count of one
    ``[0][0]`` slice of the weight.
    """

    def __init__(self, name):
        self.name = name

    def compute_weight(self, module):
        """Return the scaled version of ``module.<name>_orig``."""
        weight = getattr(module, self.name + '_orig')
        fan_in = weight.data.size(1) * weight.data[0][0].numel()
        scale = math.sqrt(2 / fan_in)
        return weight * scale

    @staticmethod
    def apply(module, name):
        """Swap parameter ``name`` for ``name + '_orig'`` and install the hook; return the hook."""
        hook = EqualLR(name)

        original = getattr(module, name)
        # Remove the parameter under its old name and re-register its data
        # under the "_orig" name; the plain attribute is set by the hook later.
        del module._parameters[name]
        module.register_parameter(name + '_orig', nn.Parameter(original.data))
        module.register_forward_pre_hook(hook)

        return hook

    def __call__(self, module, input):
        # Invoked as a forward pre-hook: refresh the scaled weight attribute.
        setattr(module, self.name, self.compute_weight(module))
|
|
|
|
|
|
def equal_lr(module, name='weight'):
    """Install the ``EqualLR`` hook for parameter ``name`` on ``module`` and return ``module``.

    The hook object returned by ``EqualLR.apply`` is discarded.
    """
    EqualLR.apply(module, name)
    return module
|
|
|
|
|
|
class Conv2DMod(nn.Module):
    """2-D convolution whose kernel is modulated per sample by a style vector.

    Args:
        in_chan: input channel count.
        out_chan: output channel count (stored as ``self.filters``).
        kernel: side length of the square kernel.
        demod: when True the modulated kernel is rescaled per output filter.
        stride, dilation: stored and used only in the padding computation.
            NOTE(review): neither is passed to ``F.conv2d`` in ``forward``.
        transfer_mode: when True, adds ``transfer_scale`` (ones) and
            ``transfer_shift`` (zeros) parameters of shape
            ``(out_chan, in_chan, 1, 1)`` tagged ``FOR_TRANSFER_LEARNING``.
        **kwargs: accepted and ignored.
    """

    def __init__(self, in_chan, out_chan, kernel, demod=True, stride=1, dilation=1, transfer_mode=False, **kwargs):
        super().__init__()
        self.filters = out_chan
        self.demod = demod
        self.kernel = kernel
        self.stride = stride
        self.dilation = dilation

        # Values are drawn with randn and then overwritten in place by the
        # Kaiming initialiser on the next line.
        self.weight = nn.Parameter(torch.randn((out_chan, in_chan, kernel, kernel)))
        nn.init.kaiming_normal_(self.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')

        self.transfer_mode = transfer_mode
        if transfer_mode:
            self.transfer_scale = nn.Parameter(torch.ones(out_chan, in_chan, 1, 1))
            self.transfer_scale.FOR_TRANSFER_LEARNING = True
            self.transfer_shift = nn.Parameter(torch.zeros(out_chan, in_chan, 1, 1))
            self.transfer_shift.FOR_TRANSFER_LEARNING = True

    def _get_same_padding(self, size, kernel, dilation, stride):
        """Return the padding computed from input ``size``, ``kernel``, ``dilation`` and ``stride``."""
        total = (size - 1) * (stride - 1) + dilation * (kernel - 1)
        return total // 2

    def forward(self, x, y):
        """Convolve ``x`` (4-D) with kernels modulated by ``y`` (one row per sample)."""
        b, c, h, w = x.shape

        weight = self.weight
        if self.transfer_mode:
            weight = weight * self.transfer_scale + self.transfer_shift

        # Broadcast (b, 1, in, 1, 1) against (1, out, in, k, k): one kernel
        # set per sample, scaled per input channel by (style + 1).
        modulation = y[:, None, :, None, None] + 1
        weights = weight[None, :, :, :, :] * modulation

        if self.demod:
            inv_norm = torch.rsqrt((weights ** 2).sum(dim=(2, 3, 4), keepdim=True) + EPS)
            weights = weights * inv_norm

        # Fold the batch into the channel dim so a grouped convolution
        # (groups=b) applies each sample's own kernels.
        x = x.reshape(1, -1, h, w)
        kernel_shape = weights.shape[2:]
        weights = weights.reshape(b * self.filters, *kernel_shape)

        padding = self._get_same_padding(h, self.kernel, self.dilation, self.stride)
        x = F.conv2d(x, weights, padding=padding, groups=b)

        return x.reshape(-1, self.filters, h, w)
|
|
|
|
|
|
class GeneratorBlock(nn.Module):
    """One generator stage: optional 2x upsample, two modulated 3x3 convs with noise, and an RGB head.

    Args:
        latent_dim: width of the style vector.
        input_channels: channel count of the incoming feature map.
        filters: channel count produced by both convolutions.
        upsample: when truthy the input is bilinearly upsampled by 2 first.
        upsample_rgb, rgba: forwarded to ``RGBBlock``.
        transfer_learning_mode: forwarded as ``transfer_mode`` to every
            ``TransferLinear``, ``Conv2DMod`` and the ``RGBBlock``.
    """

    def __init__(self, latent_dim, input_channels, filters, upsample=True, upsample_rgb=True, rgba=False,
                 transfer_learning_mode=False):
        super().__init__()
        tl_mode = transfer_learning_mode

        if upsample:
            self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        else:
            self.upsample = None

        self.to_style1 = TransferLinear(latent_dim, input_channels, transfer_mode=tl_mode)
        self.to_noise1 = TransferLinear(1, filters, transfer_mode=tl_mode)
        self.conv1 = Conv2DMod(input_channels, filters, 3, transfer_mode=tl_mode)

        self.to_style2 = TransferLinear(latent_dim, filters, transfer_mode=tl_mode)
        self.to_noise2 = TransferLinear(1, filters, transfer_mode=tl_mode)
        self.conv2 = Conv2DMod(filters, filters, 3, transfer_mode=tl_mode)

        self.activation = leaky_relu()
        self.to_rgb = RGBBlock(latent_dim, filters, upsample_rgb, rgba, transfer_mode=tl_mode)

        self.transfer_learning_mode = tl_mode

    def forward(self, x, prev_rgb, istyle, inoise):
        """Return ``(features, rgb)`` for this stage."""
        if exists(self.upsample):
            x = self.upsample(x)

        # Crop the noise on dims 1 and 2 to the (possibly upsampled) spatial
        # size of x, project its last dim to `filters`, and move that dim to
        # position 1.
        height, width = x.shape[2], x.shape[3]
        inoise = inoise[:, :height, :width, :]
        noise1 = self.to_noise1(inoise).permute((0, 3, 1, 2))
        noise2 = self.to_noise2(inoise).permute((0, 3, 1, 2))

        x = self.conv1(x, self.to_style1(istyle))
        x = self.activation(x + noise1)

        x = self.conv2(x, self.to_style2(istyle))
        x = self.activation(x + noise2)

        rgb = self.to_rgb(x, prev_rgb, istyle)
        return x, rgb
|