ee8ceed6da
- Use a gated activation layer for both attention and convs. - Add a learned relative position bias. I believe this is similar to the T5 position encodings, but it is simpler and learned. - Get rid of prepending to the attention matrix; this doesn't work that well. The model eventually learns to attend one of its heads to these blocks, but why not just concatenate if it is doing that?
1099 lines
39 KiB
Python
1099 lines
39 KiB
Python
import math
|
|
from abc import abstractmethod
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.init as init
|
|
import torch.nn.functional as F
|
|
import torch.nn.utils.spectral_norm as SpectralNorm
|
|
from math import sqrt
|
|
|
|
from utils.util import checkpoint
|
|
|
|
|
|
def exists(val):
    """Return True when ``val`` is anything other than ``None``."""
    if val is None:
        return False
    return True
|
|
|
|
|
|
def default(val, d):
    """Return ``val`` unless it is ``None``, in which case return the fallback ``d``."""
    if val is None:
        return d
    return val
|
|
|
|
|
|
def l2norm(t):
    """Rescale ``t`` to unit Euclidean (L2) norm along its last dimension."""
    last_dim = -1
    return F.normalize(t, dim=last_dim, p=2)
|
|
|
|
|
|
def ema_inplace(moving_avg, new, decay):
    """
    Exponential moving average update, in place on ``moving_avg.data``:
    moving_avg = decay * moving_avg + (1 - decay) * new.
    """
    new_weight = 1 - decay
    moving_avg.data.mul_(decay)
    moving_avg.data.add_(new, alpha=new_weight)
|
|
|
|
|
|
def laplace_smoothing(x, n_categories, eps = 1e-5):
    """Additive (Laplace) smoothing: (x + eps) / (sum(x) + n_categories * eps)."""
    numerator = x + eps
    denominator = x.sum() + n_categories * eps
    return numerator / denominator
|
|
|
|
|
|
def sample_vectors(samples, num):
    """
    Randomly pick ``num`` rows from ``samples`` (indexed along dim 0).

    Rows are drawn without replacement (random permutation) when there are at least ``num`` of them,
    otherwise with replacement (``randint``).
    """
    num_samples = samples.shape[0]
    device = samples.device
    if num_samples < num:
        indices = torch.randint(0, num_samples, (num,), device=device)
    else:
        indices = torch.randperm(num_samples, device=device)[:num]
    return samples[indices]
|
|
|
|
|
|
def kaiming_init(module,
                 a=0,
                 mode='fan_out',
                 nonlinearity='relu',
                 bias=0,
                 distribution='normal'):
    """
    Kaiming-initialise ``module.weight`` and fill ``module.bias`` (if present and not None) with ``bias``.

    :param distribution: 'normal' (kaiming_normal_) or 'uniform' (kaiming_uniform_); anything else
        raises AssertionError.
    """
    assert distribution in ['uniform', 'normal']
    init_fn = nn.init.kaiming_uniform_ if distribution == 'uniform' else nn.init.kaiming_normal_
    init_fn(module.weight, a=a, mode=mode, nonlinearity=nonlinearity)
    if getattr(module, 'bias', None) is not None:
        nn.init.constant_(module.bias, bias)
|
|
|
|
|
|
def pixel_norm(x, epsilon=1e-8):
    """Normalise ``x`` by the root-mean-square over dim 1 (plus ``epsilon`` inside the root)."""
    mean_sq = x.pow(2).mean(dim=1, keepdim=True)
    return x * (mean_sq + epsilon).rsqrt()
|
|
|
|
|
|
def initialize_weights(net_l, scale=1):
    """
    Kaiming-normal (fan_in) initialise the conv/linear layers of one or more networks.

    - Conv2d / Conv3d / Linear: weight ~ kaiming_normal_(fan_in), then multiplied by ``scale``
      (e.g. a small value for residual blocks); bias zeroed.
    - BatchNorm2d with learnable affine params: weight = 1, bias = 0. Non-affine BatchNorm has no
      parameters and is skipped (previously this crashed on ``None`` weights).

    :param net_l: a module, or a list/tuple of modules.
    :param scale: multiplier applied to conv/linear weights after init.
    """
    if not isinstance(net_l, (list, tuple)):
        net_l = [net_l]
    for net in net_l:
        for m in net.modules():
            if isinstance(m, (nn.Conv2d, nn.Conv3d, nn.Linear)):
                init.kaiming_normal_(m.weight, a=0, mode='fan_in')
                m.weight.data *= scale  # for residual block
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d) and m.weight is not None:
                init.constant_(m.weight, 1)
                init.constant_(m.bias.data, 0.0)
|
|
|
|
|
|
def make_layer(block, num_blocks, **kwarg):
    """Stack ``num_blocks`` independent instances of ``block`` into an ``nn.Sequential``.

    Args:
        block (type): nn.Module class for the basic block.
        num_blocks (int): number of blocks.
        **kwarg: keyword arguments passed to every ``block(...)`` call.

    Returns:
        nn.Sequential: the stacked blocks.
    """
    return nn.Sequential(*[block(**kwarg) for _ in range(num_blocks)])
|
|
|
|
|
|
def default_init_weights(module, scale=1):
    """Initialize network weights.

    Every Conv2d and Linear inside ``module`` gets Kaiming init (fan_in, bias filled with 0 via
    ``kaiming_init``) and then has its weight multiplied by ``scale``.

    Args:
        module (nn.Module): Module whose submodules are initialized.
        scale (float): Scale initialized weights, especially for residual
            blocks.
    """
    for m in module.modules():
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            kaiming_init(m, a=0, mode='fan_in', bias=0)
            m.weight.data *= scale
|
|
|
|
|
|
# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
# NOTE: another `SiLU` class is defined later in this module and rebinds this name, so
# lookups of `SiLU` at call time resolve to that later definition.
class SiLU(nn.Module):
    """Element-wise SiLU: x * sigmoid(x)."""

    def forward(self, x):
        gate = torch.sigmoid(x)
        return gate * x
|
|
|
|
|
|
class GroupNorm32(nn.GroupNorm):
    """GroupNorm evaluated in float32; the result is cast back to the input's dtype."""

    def forward(self, x):
        normed = super().forward(x.float())
        return normed.type(x.dtype)
|
|
|
|
|
|
def conv_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D convolution module; remaining arguments go to the nn.ConvNd constructor.

    :raises ValueError: if ``dims`` is not 1, 2 or 3.
    """
    for n, conv_cls in ((1, nn.Conv1d), (2, nn.Conv2d), (3, nn.Conv3d)):
        if dims == n:
            return conv_cls(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")
|
|
|
|
|
|
def linear(*args, **kwargs):
    """
    Create an nn.Linear module; all arguments are forwarded unchanged.
    """
    module = nn.Linear(*args, **kwargs)
    return module
|
|
|
|
|
|
def avg_pool_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D average pooling module; remaining arguments go to nn.AvgPoolNd.

    :raises ValueError: if ``dims`` is not 1, 2 or 3.
    """
    for n, pool_cls in ((1, nn.AvgPool1d), (2, nn.AvgPool2d), (3, nn.AvgPool3d)):
        if dims == n:
            return pool_cls(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")
|
|
|
|
|
|
def update_ema(target_params, source_params, rate=0.99):
    """
    Move target parameters towards source parameters with an exponential moving average:
    target = rate * target + (1 - rate) * source, done in place outside autograd.

    :param target_params: the target parameter sequence.
    :param source_params: the source parameter sequence (paired with targets via zip).
    :param rate: the EMA rate (closer to 1 means slower).
    """
    blend = 1 - rate
    for target, source in zip(target_params, source_params):
        target.detach().mul_(rate).add_(source, alpha=blend)
|
|
|
|
|
|
def zero_module(module):
    """
    Set every parameter of ``module`` to zero in place (untracked by autograd) and return the module.
    """
    with torch.no_grad():
        for param in module.parameters():
            param.zero_()
    return module
|
|
|
|
|
|
def scale_module(module, scale):
    """
    Multiply every parameter of ``module`` by ``scale`` in place (untracked by autograd) and return it.
    """
    with torch.no_grad():
        for param in module.parameters():
            param.mul_(scale)
    return module
|
|
|
|
|
|
def mean_flat(tensor):
    """
    Average over every dimension except the first (batch) one.
    """
    non_batch_dims = list(range(1, tensor.dim()))
    return tensor.mean(dim=non_batch_dims)
|
|
|
|
|
|
def normalization(channels):
    """
    Make a standard normalization layer (GroupNorm computed in float32).

    The starting group count is 8 for <= 16 channels, 16 for <= 64 channels, and 32 otherwise.
    It is then halved until it divides ``channels``.

    :param channels: number of input channels.
    :return: an nn.Module for normalization.
    :raises AssertionError: if the halving ends at 2 or fewer groups (e.g. odd channel counts).
    """
    if channels <= 16:
        groups = 8
    elif channels <= 64:
        groups = 16
    else:
        groups = 32
    while channels % groups:
        groups //= 2
    assert groups > 2
    return GroupNorm32(groups, channels)
|
|
|
|
|
|
class AttentionPool2d(nn.Module):
    """
    Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py

    Pools an [N x C x ...] feature map with multi-head attention.
    The spatial mean is prepended as an extra token, a learned positional embedding is added,
    and the attention output at that mean token (position 0) is returned as [N x output_dim].

    :param spacial_dim: side length of the (square) spatial grid; sizes the positional embedding.
    :param embed_dim: input channel count.
    :param num_heads_channels: channels per attention head.
    :param output_dim: output channel count (defaults to embed_dim).
    """

    def __init__(
        self,
        spacial_dim: int,
        embed_dim: int,
        num_heads_channels: int,
        output_dim: int = None,
    ):
        super().__init__()
        # One slot per spatial position plus one for the prepended mean token.
        num_tokens = spacial_dim ** 2 + 1
        self.positional_embedding = nn.Parameter(
            torch.randn(embed_dim, num_tokens) / embed_dim ** 0.5
        )
        self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
        self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
        self.num_heads = embed_dim // num_heads_channels
        self.attention = QKVAttention(self.num_heads)

    def forward(self, x):
        batch, channels = x.shape[:2]
        flat = x.reshape(batch, channels, -1)  # NC(HW)
        pooled = flat.mean(dim=-1, keepdim=True)
        tokens = torch.cat([pooled, flat], dim=-1)  # NC(HW+1)
        pos = self.positional_embedding[None, :, :tokens.shape[-1]].to(tokens.dtype)
        tokens = tokens + pos  # NC(HW+1)
        attended = self.attention(self.qkv_proj(tokens))
        out = self.c_proj(attended)
        return out[:, :, 0]
|
|
|
|
|
|
class TimestepBlock(nn.Module):
    """
    Marker base class for modules whose forward() receives timestep embeddings as a second argument.
    TimestepEmbedSequential checks for this type to decide whether to pass the embeddings along.
    """

    @abstractmethod
    def forward(self, x, emb):
        """
        Run the module on `x`, conditioned on the timestep embeddings `emb`.
        """
        return None
|
|
|
|
|
|
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    """
    An nn.Sequential whose forward also takes timestep embeddings `emb`.
    Children that are TimestepBlocks are called as child(x, emb); all others as child(x).
    """

    def forward(self, x, emb):
        out = x
        for module in self:
            if isinstance(module, TimestepBlock):
                out = module(out, emb)
                continue
            out = module(out)
        return out
|
|
|
|
|
|
class Upsample(nn.Module):
    """
    An upsampling layer with an optional convolution.

    Upsampling is nearest-neighbour by ``factor``.

    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied after upsampling.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
                 upsampling occurs in the inner-two dimensions only.
    :param out_channels: output channels of the optional convolution (defaults to channels).
    :param factor: upsampling factor.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None, factor=2):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        self.factor = factor
        if use_conv:
            # 1D signals use a wider kernel (5, padding 2); 2D/3D use 3 with padding 1.
            ksize = 3
            pad = 1
            if dims == 1:
                ksize = 5
                pad = 2
            self.conv = conv_nd(dims, self.channels, self.out_channels, ksize, padding=pad)

    def forward(self, x):
        assert x.shape[1] == self.channels
        if self.dims == 3:
            # Only the inner two dims are upsampled; the leading (e.g. temporal) dim is kept.
            # Previously there was no `else`, so 3D inputs were then upsampled a second time on
            # every dimension by the generic branch below.
            x = F.interpolate(
                x, (x.shape[2], x.shape[3] * self.factor, x.shape[4] * self.factor), mode="nearest"
            )
        else:
            x = F.interpolate(x, scale_factor=self.factor, mode="nearest")
        if self.use_conv:
            x = self.conv(x)
        return x
|
|
|
|
|
|
class Downsample(nn.Module):
    """
    A downsampling layer with an optional convolution.

    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a (strided, kernel 3) convolution is applied;
                     otherwise average pooling is used and out_channels must equal channels.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
                 downsampling occurs in the inner-two dimensions.
    :param out_channels: output channels when use_conv is set (defaults to channels).
    :param factor: downsampling factor (the stride).
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None, factor=2):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        ksize = 3
        pad = 1
        # For 3D signals keep the leading dim intact, as documented (previously all three dims were
        # strided, which disagreed with the docstring and with Upsample).
        stride = factor if dims != 3 else (1, factor, factor)
        if use_conv:
            self.op = conv_nd(
                dims, self.channels, self.out_channels, ksize, stride=stride, padding=pad
            )
        else:
            assert self.channels == self.out_channels
            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)

    def forward(self, x):
        assert x.shape[1] == self.channels
        return self.op(x)
|
|
|
|
|
|
class cGLU(nn.Module):
    """
    Gated GELU for channel-first architectures.

    A 1x1 Conv1d projects ``dim_in`` channels to ``2 * dim_out``. The first half is then multiplied
    by GELU of the second half.
    """

    def __init__(self, dim_in, dim_out=None):
        super().__init__()
        if dim_out is None:
            dim_out = dim_in
        self.proj = nn.Conv1d(dim_in, dim_out * 2, 1)

    def forward(self, x):
        projected = self.proj(x)
        value, gate = torch.chunk(projected, 2, dim=1)
        return value * F.gelu(gate)
|
|
|
|
|
|
class ResBlock(nn.Module):
    """
    A residual block that can optionally change the number of channels.

    :param channels: the number of input channels.
    :param dropout: the rate of dropout.
    :param out_channels: if specified, the number of out channels.
    :param use_conv: if True and out_channels is specified, use a spatial
        convolution instead of a smaller 1x1 convolution to change the
        channels in the skip connection. Also forwarded to the Upsample /
        Downsample layers when up / down is set.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param up: if True, use this block for upsampling.
    :param down: if True, use this block for downsampling.
    :param kernel_size: kernel size of the spatial convolutions. Odd sizes
        preserve the spatial shape (padding is kernel_size // 2).
    :param checkpointing_enabled: if True, run the forward pass through
        gradient checkpointing (utils.util.checkpoint).
    """

    def __init__(
        self,
        channels,
        dropout=0,
        out_channels=None,
        use_conv=False,
        dims=2,
        up=False,
        down=False,
        kernel_size=3,
        checkpointing_enabled=True,
    ):
        super().__init__()
        self.channels = channels
        self.dropout = dropout
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.checkpointing_enabled = checkpointing_enabled
        # "Same" padding for odd kernels. Identical to the previous hard-coded values for
        # kernel_size 3 (1) and 5 (2). Other odd sizes (1, 7, ...) previously got padding 2,
        # which changed the spatial size and broke the residual sum.
        padding = kernel_size // 2

        self.in_layers = nn.Sequential(
            normalization(channels),
            nn.SiLU(),
            conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding),
        )

        self.updown = up or down

        if up:
            self.h_upd = Upsample(channels, use_conv, dims)
            self.x_upd = Upsample(channels, use_conv, dims)
        elif down:
            self.h_upd = Downsample(channels, use_conv, dims)
            self.x_upd = Downsample(channels, use_conv, dims)
        else:
            self.h_upd = self.x_upd = nn.Identity()

        self.out_layers = nn.Sequential(
            normalization(self.out_channels),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            # Zero-initialised: at init the residual branch outputs zeros, so the block equals its skip path.
            zero_module(
                conv_nd(dims, self.out_channels, self.out_channels, kernel_size, padding=padding)
            ),
        )

        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        elif use_conv:
            self.skip_connection = conv_nd(
                dims, channels, self.out_channels, kernel_size, padding=padding
            )
        else:
            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)

    def forward(self, x):
        """
        Apply the block to a Tensor.

        :param x: an [N x C x ...] Tensor of features.
        :return: an [N x C x ...] Tensor of outputs.
        """
        if self.checkpointing_enabled:
            return checkpoint(
                self._forward, x
            )
        else:
            return self._forward(x)

    def _forward(self, x):
        if self.updown:
            # Resample between the norm+activation and the conv of in_layers, and resample the
            # skip input the same way so both branches keep matching shapes.
            in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
            h = in_rest(x)
            h = self.h_upd(h)
            x = self.x_upd(x)
            h = in_conv(h)
        else:
            h = self.in_layers(x)
        h = self.out_layers(h)
        return self.skip_connection(x) + h
|
|
|
|
|
|
def build_local_attention_mask(n, l, fixed_region=0):
    """
    Builds an attention mask that focuses attention on a local region.

    Includes provisions for a "fixed_region" at the start of the sequence where full attention weights will be applied.

    Args:
        n: Size of returned matrix (maximum sequence size)
        l: Size of local context (uni-directional): position i may attend to j when |i - j| < l.
        fixed_region: The number of sequence elements at the start of the sequence that get full attention
                      (both as queries and as keys).

    Returns:
        An [n x n] boolean mask (True = may attend) that can be applied to AttentionBlock to achieve local attention.
    """
    assert l*2 < n, f'Local context must be less than global context. {l}, {n}'
    # The previous implementation divided by (l - 1), which produced NaNs and an all-False
    # mask (every query fully masked) for l == 1.
    assert l >= 1, f'Local context must be at least 1. {l}'
    positions = torch.arange(0, n)
    distance = (positions.unsqueeze(-1) - positions.unsqueeze(0)).abs()
    mask = distance < l
    mask[:fixed_region] = True
    mask[:, :fixed_region] = True
    return mask
|
|
|
|
|
|
def test_local_attention_mask():
    """Print an example local attention mask (n=9, l=4, fixed_region=1) for visual inspection."""
    example = build_local_attention_mask(9, 4, 1)
    print(example)
|
|
|
|
|
|
class RelativeQKBias(nn.Module):
    """
    Very simple relative position bias scheme which should be directly added to the QK matrix.

    The bias depends only on the distance |i - j| between query i and key j.
    Distance 0 uses emb[l], distance d < l uses emb[l - d], and every distance >= l shares emb[0].

    :param l: number of distinct (learned) near-distance buckets.
    :param max_positions: kept for backward compatibility; indices are now computed per call, so
        sequence length is no longer limited by it. The previous version precomputed a
        max_positions x max_positions int64 index buffer (~128 MB at the default 4000).
    """
    def __init__(self, l, max_positions=4000):
        super().__init__()
        self.l = l
        self.max_positions = max_positions
        self.emb = nn.Parameter(torch.randn(l+1) * .01)

    def forward(self, n):
        """Return the [1 x n x n] bias for a sequence of length n."""
        positions = torch.arange(n, device=self.emb.device)
        distance = (positions.unsqueeze(-1) - positions.unsqueeze(0)).abs()
        index = (self.l - distance).clamp(0, self.l)
        return self.emb[index].view(1, n, n)
|
|
|
|
|
|
class AttentionBlock(nn.Module):
    """
    An attention block that allows spatial positions to attend to each other.

    Originally ported from here, but adapted to the N-d case.
    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.

    :param channels: number of input channels.
    :param num_heads: number of attention heads (used when num_head_channels == -1).
    :param num_head_channels: if not -1, the head count becomes channels // num_head_channels.
    :param out_channels: number of output channels (defaults to channels).
    :param use_new_attention_order: use QKVAttention (split qkv, then heads) instead of QKVAttentionLegacy.
    :param do_checkpoint: run _forward through gradient checkpointing.
    :param do_activation: apply SiLU after the input normalization.
    """

    def __init__(
        self,
        channels,
        num_heads=1,
        num_head_channels=-1,
        out_channels=None,
        use_new_attention_order=False,
        do_checkpoint=True,
        do_activation=False,
    ):
        super().__init__()
        self.channels = channels
        out_channels = channels if out_channels is None else out_channels
        self.do_checkpoint = do_checkpoint
        self.do_activation = do_activation
        if num_head_channels == -1:
            self.num_heads = num_heads
        else:
            assert (
                channels % num_head_channels == 0
            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
            self.num_heads = channels // num_head_channels
        self.norm = normalization(channels)
        self.qkv = conv_nd(1, channels, out_channels * 3, 1)
        if use_new_attention_order:
            # split qkv before split heads
            self.attention = QKVAttention(self.num_heads)
        else:
            # split heads before split qkv
            self.attention = QKVAttentionLegacy(self.num_heads)

        self.x_proj = nn.Identity() if out_channels == channels else conv_nd(1, channels, out_channels, 1)
        self.proj_out = zero_module(conv_nd(1, out_channels, out_channels, 1))

    def forward(self, x, mask=None, qk_bias=None):
        """
        :param x: an [N x C x ...] tensor.
        :param mask: optional boolean attention mask, [T x T] or [N x T x T] with T the flattened
                     spatial size (larger masks are cropped). True means "may attend".
        :param qk_bias: optional additive bias for the attention logits. With checkpointing enabled
                        it is only supported together with a mask.
        """
        if not self.do_checkpoint:
            # qk_bias used to be silently dropped on this path; 0 is _forward's "no bias" value.
            return self._forward(x, mask, 0 if qk_bias is None else qk_bias)
        if mask is None:
            if qk_bias is not None:
                # Raised explicitly (not `assert False`) so it is not stripped under `python -O`.
                raise AssertionError('unsupported: qk_bias but no mask')
            return checkpoint(self._forward, x)
        if qk_bias is None:
            return checkpoint(self._forward, x, mask)
        return checkpoint(self._forward, x, mask, qk_bias)

    def _forward(self, x, mask=None, qk_bias=0):
        b, c, *spatial = x.shape
        x = x.reshape(b, c, -1)
        # Attention runs over the flattened spatial positions. Use that length for cropping the mask
        # (previously the last *unflattened* dim was used, which is wrong for N-d inputs).
        length = x.shape[-1]
        if mask is not None:
            if len(mask.shape) == 2:
                mask = mask.unsqueeze(0).repeat(b, 1, 1)
            if mask.shape[1] != length:
                mask = mask[:, :length, :length]

        x = self.norm(x)
        if self.do_activation:
            x = F.silu(x, inplace=True)
        qkv = self.qkv(x)
        h = self.attention(qkv, mask, qk_bias)
        h = self.proj_out(h)
        # NOTE(review): the residual adds x_proj of the *normalized* (and optionally activated) input,
        # not the raw block input; confirm this is intended.
        xp = self.x_proj(x)
        return (xp + h).reshape(b, xp.shape[1], *spatial)
|
|
|
|
|
|
class QKVAttentionLegacy(nn.Module):
    """
    A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping.
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv, mask=None, qk_bias=0):
        """
        Apply QKV attention.

        :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
        :param mask: optional boolean mask, [T x T] or [N x T x T] (a leading dim of 1 is broadcast);
                     False entries are excluded from attention.
        :param qk_bias: additive bias for the attention logits, broadcastable to [N*H x T x T].
        :return: an [N x (H * C) x T] tensor after attention.
        """
        bs, width, length = qkv.shape
        assert width % (3 * self.n_heads) == 0
        ch = width // (3 * self.n_heads)
        q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
        scale = 1 / math.sqrt(math.sqrt(ch))
        weight = torch.einsum(
            "bct,bcs->bts", q * scale, k * scale
        )  # More stable with f16 than dividing afterwards
        weight = weight + qk_bias
        if mask is not None:
            if mask.dim() == 2:
                mask = mask.unsqueeze(0)
            if mask.shape[0] != 1:
                # Rows of `weight` are batch-major (row = b * n_heads + h), so each sample's mask
                # must be repeated n_heads times consecutively. The previous `repeat` tiled the
                # masks and paired them with the wrong samples whenever masks differed per sample.
                mask = mask.repeat_interleave(self.n_heads, dim=0)
            weight = weight.masked_fill(mask.logical_not(), -math.inf)
        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
        a = torch.einsum("bts,bcs->bct", weight, v)

        return a.reshape(bs, -1, length)
|
|
|
|
|
|
class QKVAttention(nn.Module):
    """
    A module which performs QKV attention and splits in a different order.
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv, mask=None, qk_bias=0):
        """
        Apply QKV attention.

        :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
        :param mask: optional boolean mask, [T x T] or [N x T x T] (a leading dim of 1 is broadcast);
                     False entries are excluded from attention.
        :param qk_bias: additive bias for the attention logits, broadcastable to [N*H x T x T].
        :return: an [N x (H * C) x T] tensor after attention.
        """
        bs, width, length = qkv.shape
        assert width % (3 * self.n_heads) == 0
        ch = width // (3 * self.n_heads)
        q, k, v = qkv.chunk(3, dim=1)
        scale = 1 / math.sqrt(math.sqrt(ch))
        weight = torch.einsum(
            "bct,bcs->bts",
            (q * scale).view(bs * self.n_heads, ch, length),
            (k * scale).view(bs * self.n_heads, ch, length),
        )  # More stable with f16 than dividing afterwards
        # qk_bias used to be accepted but ignored here; apply it like QKVAttentionLegacy does.
        weight = weight + qk_bias
        if mask is not None:
            if mask.dim() == 2:
                mask = mask.unsqueeze(0)
            if mask.shape[0] != 1:
                # Rows of `weight` are batch-major (row = b * n_heads + h), so each sample's mask
                # must be repeated n_heads times consecutively rather than tiled.
                mask = mask.repeat_interleave(self.n_heads, dim=0)
            weight = weight.masked_fill(mask.logical_not(), -math.inf)
        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
        a = torch.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
        return a.reshape(bs, -1, length)
|
|
|
|
|
|
def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros', align_corners=True):
    """Warp an image or feature map with optical flow.

    Args:
        x (Tensor): size (N, C, H, W)
        flow (Tensor): size (N, H, W, 2), displacement in pixels; last dim is (dx, dy)
        interp_mode (str): 'nearest' or 'bilinear'
        padding_mode (str): 'zeros' or 'border' or 'reflection'
        align_corners (bool): passed to F.grid_sample. The grid below is normalised with the
            (size - 1) convention, which corresponds to align_corners=True; with that setting a
            zero flow reproduces x. The previous implementation relied on grid_sample's default
            (False); pass align_corners=False to reproduce that behaviour.

    Returns:
        Tensor: warped image or feature map
    """
    assert x.size()[-2:] == flow.size()[1:3]
    _, _, H, W = x.size()
    # mesh grid
    grid_y, grid_x = torch.meshgrid(torch.arange(0, H), torch.arange(0, W))
    grid = torch.stack((grid_x, grid_y), 2).float()  # (H, W, 2), last dim is (x, y)
    grid = grid.type_as(x)
    vgrid = grid + flow
    # scale grid to [-1,1]
    vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(W - 1, 1) - 1.0
    vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(H - 1, 1) - 1.0
    vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3)
    output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode,
                           align_corners=align_corners)
    return output
|
|
|
|
|
|
class PixelUnshuffle(nn.Module):
    """
    Space-to-depth: fold each r x r spatial patch into the channel dimension.

    Input (b, f, H, W) -> output (b, f * r * r, H // r, W // r). Output channel
    f_idx * r * r + row_offset * r + col_offset holds x[:, f_idx, row_offset::r, col_offset::r].
    """
    def __init__(self, reduction_factor):
        super(PixelUnshuffle, self).__init__()
        self.r = reduction_factor

    def forward(self, x):
        r = self.r
        batch, feats, rows, cols = x.shape
        out_rows, out_cols = rows // r, cols // r
        patches = x.contiguous().view(batch, feats, out_rows, r, out_cols, r)
        patches = patches.permute(0, 1, 3, 5, 2, 4).contiguous()
        return patches.view(batch, feats * (r ** 2), out_rows, out_cols)
|
|
|
|
|
|
# simply define a silu function
def silu(input):
    '''
    Applies the Sigmoid Linear Unit (SiLU) function element-wise:
        SiLU(x) = x * sigmoid(x)
    '''
    gate = torch.sigmoid(input)
    return torch.mul(input, gate)
|
|
|
|
# Module wrapper around the SiLU function so it can be used as a layer in models.
# This definition rebinds the name `SiLU`, which is also defined earlier in this module.
class SiLU(nn.Module):
    '''
    Applies the Sigmoid Linear Unit (SiLU) function element-wise:
        SiLU(x) = x * sigmoid(x)

    Shape:
        - Input: (N, *) where * means any number of additional dimensions
        - Output: (N, *), same shape as the input

    References:
        - Related paper:
        https://arxiv.org/pdf/1606.08415.pdf

    Examples:
        >>> m = SiLU()
        >>> input = torch.randn(2)
        >>> output = m(input)
    '''

    def __init__(self):
        '''
        Init method (no parameters).
        '''
        super().__init__()

    def forward(self, input):
        '''
        Forward pass: returns input * sigmoid(input), the same computation as the module-level `silu`.
        '''
        return input * torch.sigmoid(input)
|
|
|
|
|
|
class ConvBnRelu(nn.Module):
    '''
    Convenience class with Conv->BN->ReLU. Includes weight initialization and auto-padding for standard
    kernel sizes (1, 3, 5, 7).

    NOTE: a second `ConvBnRelu` (adding `weight_init_factor`) is defined later in this module and rebinds
    the name. This class therefore uses zero-argument `super()`: the former `super(ConvBnRelu, self)`
    looked the name up at call time, got the later class, and would raise TypeError here.
    '''
    def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, activation=True, norm=True, bias=True):
        super().__init__()
        padding_map = {1: 0, 3: 1, 5: 2, 7: 3}
        assert kernel_size in padding_map.keys()
        self.conv = nn.Conv2d(filters_in, filters_out, kernel_size, stride, padding_map[kernel_size], bias=bias)
        if norm:
            self.bn = nn.BatchNorm2d(filters_out)
        else:
            self.bn = None
        if activation:
            self.relu = nn.ReLU()
        else:
            self.relu = None

        # Init params.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu' if self.relu else 'linear')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.conv(x)
        if self.bn:
            x = self.bn(x)
        if self.relu:
            return self.relu(x)
        else:
            return x
|
|
|
|
|
|
class ConvBnSilu(nn.Module):
    '''
    Convenience class with Conv->BN->SiLU. Includes weight initialization and auto-padding for standard
    kernel sizes (1, 3, 5, 7 -> padding 0, 1, 2, 3).

    Conv weights get Kaiming-normal init (fan_out, 'relu' gain when activated, else 'linear') scaled by
    weight_init_factor. Conv bias is zeroed; norm weight/bias are set to 1/0.
    '''
    def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, activation=True, norm=True, bias=True, weight_init_factor=1):
        super().__init__()
        assert kernel_size in {1, 3, 5, 7}
        # Odd kernels with padding k // 2 keep the spatial size at stride 1.
        self.conv = nn.Conv2d(filters_in, filters_out, kernel_size, stride, kernel_size // 2, bias=bias)
        self.bn = nn.BatchNorm2d(filters_out) if norm else None
        self.silu = SiLU() if activation else None

        # Init params.
        gain_nonlinearity = 'relu' if self.silu is not None else 'linear'
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity=gain_nonlinearity)
                module.weight.data *= weight_init_factor
                if module.bias is not None:
                    module.bias.data.zero_()
            elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)

    def forward(self, x):
        out = self.conv(x)
        if self.bn is not None:
            out = self.bn(out)
        return out if self.silu is None else self.silu(out)
|
|
|
|
|
|
class ConvBnLelu(nn.Module):
    '''
    Convenience class with Conv->BN->LeakyReLU(0.1). Includes weight initialization and auto-padding for
    standard kernel sizes (1, 3, 5, 7 -> padding 0, 1, 2, 3).

    Conv weights get Kaiming-normal init (fan_out, leaky_relu gain with a=.1 when activated, else 'linear')
    scaled by weight_init_factor. Conv bias is zeroed; norm weight/bias are set to 1/0.
    '''
    def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, activation=True, norm=True, bias=True, weight_init_factor=1):
        super().__init__()
        assert kernel_size in {1, 3, 5, 7}
        self.conv = nn.Conv2d(filters_in, filters_out, kernel_size, stride, kernel_size // 2, bias=bias)
        self.bn = nn.BatchNorm2d(filters_out) if norm else None
        self.lelu = nn.LeakyReLU(negative_slope=.1) if activation else None

        # Init params.
        gain_nonlinearity = 'leaky_relu' if self.lelu is not None else 'linear'
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, a=.1, mode='fan_out', nonlinearity=gain_nonlinearity)
                module.weight.data *= weight_init_factor
                if module.bias is not None:
                    module.bias.data.zero_()
            elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)

    def forward(self, x):
        out = self.conv(x)
        if self.bn is not None:
            out = self.bn(out)
        return out if self.lelu is None else self.lelu(out)
|
|
|
|
|
|
class ConvGnLelu(nn.Module):
    '''
    Convenience class with Conv->GroupNorm->LeakyReLU(0.2). Includes weight initialization and auto-padding
    for standard kernel sizes (1, 3, 5, 7).

    Conv weights get Kaiming-normal init (fan_out) scaled by weight_init_factor. The init gain uses the
    same negative slope as the activation (it previously used a=.1 against a .2 slope). Conv bias is
    zeroed; norm weight/bias are set to 1/0.
    '''
    def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, activation=True, norm=True, bias=True, num_groups=8, weight_init_factor=1):
        super(ConvGnLelu, self).__init__()
        negative_slope = .2
        padding_map = {1: 0, 3: 1, 5: 2, 7: 3}
        assert kernel_size in padding_map.keys()
        self.conv = nn.Conv2d(filters_in, filters_out, kernel_size, stride, padding_map[kernel_size], bias=bias)
        if norm:
            self.gn = nn.GroupNorm(num_groups, filters_out)
        else:
            self.gn = None
        if activation:
            self.lelu = nn.LeakyReLU(negative_slope=negative_slope)
        else:
            self.lelu = None

        # Init params.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, a=negative_slope, mode='fan_out',
                                        nonlinearity='leaky_relu' if self.lelu else 'linear')
                m.weight.data *= weight_init_factor
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.conv(x)
        if self.gn:
            x = self.gn(x)
        if self.lelu:
            return self.lelu(x)
        else:
            return x
|
|
|
|
|
|
class ConvGnSilu(nn.Module):
    '''
    Convenience class with Conv->GroupNorm->SiLU. Includes weight initialization and auto-padding for
    standard kernel sizes (1, 3, 5, 7 -> padding 0, 1, 2, 3).

    `convnd` selects the convolution class (nn.Conv2d by default; e.g. nn.Conv1d for sequences).
    Conv weights get Kaiming-normal init (fan_out, 'relu' gain when activated, else 'linear') scaled by
    weight_init_factor. Conv bias is zeroed; norm weight/bias are set to 1/0.
    '''
    def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, activation=True, norm=True, bias=True, num_groups=8, weight_init_factor=1, convnd=nn.Conv2d):
        super().__init__()
        assert kernel_size in {1, 3, 5, 7}
        self.conv = convnd(filters_in, filters_out, kernel_size, stride, kernel_size // 2, bias=bias)
        self.gn = nn.GroupNorm(num_groups, filters_out) if norm else None
        self.silu = SiLU() if activation else None

        # Init params.
        gain_nonlinearity = 'relu' if self.silu is not None else 'linear'
        for module in self.modules():
            if isinstance(module, convnd):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity=gain_nonlinearity)
                module.weight.data *= weight_init_factor
                if module.bias is not None:
                    module.bias.data.zero_()
            elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)

    def forward(self, x):
        out = self.conv(x)
        if self.gn is not None:
            out = self.gn(out)
        return out if self.silu is None else self.silu(out)
|
|
|
|
|
|
class ConvBnRelu(nn.Module):
    '''
    Convenience class with Conv->BN->ReLU. Includes weight initialization and auto-padding for standard
    kernel sizes (1, 3, 5, 7 -> padding 0, 1, 2, 3).

    Conv weights get Kaiming-normal init (fan_out, 'relu' gain when activated, else 'linear') scaled by
    weight_init_factor. Conv bias is zeroed; norm weight/bias are set to 1/0.
    '''
    def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, activation=True, norm=True, bias=True, weight_init_factor=1):
        super().__init__()
        assert kernel_size in {1, 3, 5, 7}
        self.conv = nn.Conv2d(filters_in, filters_out, kernel_size, stride, kernel_size // 2, bias=bias)
        self.bn = nn.BatchNorm2d(filters_out) if norm else None
        self.relu = nn.ReLU() if activation else None

        # Init params.
        gain_nonlinearity = 'relu' if self.relu is not None else 'linear'
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity=gain_nonlinearity)
                module.weight.data *= weight_init_factor
                if module.bias is not None:
                    module.bias.data.zero_()
            elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)

    def forward(self, x):
        out = self.conv(x)
        if self.bn is not None:
            out = self.bn(out)
        return out if self.relu is None else self.relu(out)
|
|
|
|
|
|
class MultiConvBlock(nn.Module):
    """
    Simple way to chain multiple conv->norm->act layers (ConvBnLelu) together.

    Layer widths are filters_in -> filters_mid (depth - 1 times) -> filters_out. The last layer has
    neither activation nor norm. The output is multiplied by a learned scalar (initialised to
    scale_init) and shifted by a learned bias. Optional `noise`, scaled by a learned factor
    initialised to .01, is added to the input.
    """
    def __init__(self, filters_in, filters_mid, filters_out, kernel_size, depth, scale_init=1, norm=False, weight_init_factor=1):
        assert depth >= 2
        super(MultiConvBlock, self).__init__()
        self.noise_scale = nn.Parameter(torch.full((1,), fill_value=.01))
        self.bnconvs = nn.ModuleList([ConvBnLelu(filters_in, filters_mid, kernel_size, norm=norm, bias=False, weight_init_factor=weight_init_factor)] +
                                     [ConvBnLelu(filters_mid, filters_mid, kernel_size, norm=norm, bias=False, weight_init_factor=weight_init_factor) for i in range(depth - 2)] +
                                     [ConvBnLelu(filters_mid, filters_out, kernel_size, activation=False, norm=False, bias=False, weight_init_factor=weight_init_factor)])
        self.scale = nn.Parameter(torch.full((1,), fill_value=scale_init, dtype=torch.float))
        self.bias = nn.Parameter(torch.zeros(1))

    def forward(self, x, noise=None):
        if noise is not None:
            x = x + noise * self.noise_scale
        for m in self.bnconvs:
            # Call the module rather than m.forward(...) so registered forward hooks run.
            x = m(x)
        return x * self.scale + self.bias
|
|
|
|
|
|
class ExpansionBlock(nn.Module):
    """
    Upsamples 2x and reduces the incoming filter count (to filters_in // 2 by default). Structure is
    preserved by also taking a passthrough feed along with the feature representation.

    forward(input, passthrough):
        input: feature signal, (b, f, w, h)
        passthrough: structure signal, (b, filters_out, w*2, h*2)
        returns the conjoined upsample, (b, filters_out, w*2, h*2)
    """
    def __init__(self, filters_in, filters_out=None, block=ConvGnSilu):
        super().__init__()
        filters_out = filters_in // 2 if filters_out is None else filters_out
        self.decimate = block(filters_in, filters_out, kernel_size=1, bias=False, activation=False, norm=True)
        self.process_passthrough = block(filters_out, filters_out, kernel_size=3, bias=True, activation=False, norm=True)
        self.conjoin = block(filters_out*2, filters_out, kernel_size=3, bias=False, activation=True, norm=False)
        self.process = block(filters_out, filters_out, kernel_size=3, bias=False, activation=True, norm=True)

    def forward(self, input, passthrough):
        upsampled = F.interpolate(input, scale_factor=2, mode="nearest")
        features = self.decimate(upsampled)
        structure = self.process_passthrough(passthrough)
        joined = self.conjoin(torch.cat([features, structure], dim=1))
        return self.process(joined)
|
|
|
|
|
|
# Block that upsamples 2x and reduces incoming filters by 2x. It preserves structure by taking a passthrough feed
|
|
# along with the feature representation.
|
|
# Differs from ExpansionBlock because it performs all processing in 2xfilter space and decimates at the last step.
|
|
class ExpansionBlock2(nn.Module):
    """Upsample 2x and fuse with a passthrough, keeping the fusion step wide.

    forward() performs:
      1. nearest-neighbour 2x upsample of `input`;
      2. 1x1 conv block (`decimate`) down to `filters_out` channels;
      3. 3x3 conv block over `passthrough` (`process_passthrough`);
      4. channel-concat of (2) and (3), processed at 2*filters_out channels
         by a 3x3 conv block (`conjoin`);
      5. a 3x3 conv block (`reduce`) back down to `filters_out` channels.

    Args:
        filters_in: channels of `input`.
        filters_out: output channels, and the channel count `passthrough`
            must have; defaults to `filters_in // 2`.
        block: conv block factory, called as
            block(in, out, kernel_size=..., bias=..., activation=..., norm=...).

    Shapes:
        input:       (b, filters_in, h, w)
        passthrough: (b, filters_out, 2*h, 2*w)
        output:      (b, filters_out, 2*h, 2*w)
    """

    def __init__(self, filters_in, filters_out=None, block=ConvGnSilu):
        super().__init__()
        out = filters_in // 2 if filters_out is None else filters_out
        wide = out * 2
        # Creation order matches the original so state_dict keys are unchanged.
        self.decimate = block(filters_in, out, kernel_size=1, bias=False, activation=False, norm=True)
        self.process_passthrough = block(out, out, kernel_size=3, bias=True, activation=False, norm=True)
        self.conjoin = block(wide, wide, kernel_size=3, bias=False, activation=True, norm=False)
        self.reduce = block(wide, out, kernel_size=3, bias=False, activation=True, norm=True)

    def forward(self, input, passthrough):
        upsampled = F.interpolate(input, scale_factor=2, mode="nearest")
        features = self.decimate(upsampled)
        structure = self.process_passthrough(passthrough)
        fused = self.conjoin(torch.cat([features, structure], dim=1))
        return self.reduce(fused)
|
|
|
|
|
|
# Similar to ExpansionBlock2 but does not upsample.
|
|
class ConjoinBlock(nn.Module):
    """Fuse a feature map with a same-resolution passthrough (no upsampling).

    The two inputs are concatenated on the channel axis, run through a 3x3
    conv block that keeps the joint channel count, then projected to
    `filters_out` channels by a 1x1 conv block without activation.

    Args:
        filters_in: channels of `input`.
        filters_out: output channels; defaults to `filters_in`.
        filters_pt: channels of `passthrough`; defaults to `filters_in`.
        block: conv block factory, called as
            block(in, out, kernel_size=..., bias=..., activation=..., norm=...).
        norm: forwarded to both conv blocks.
    """

    def __init__(self, filters_in, filters_out=None, filters_pt=None, block=ConvGnSilu, norm=True):
        super().__init__()
        filters_out = filters_in if filters_out is None else filters_out
        filters_pt = filters_in if filters_pt is None else filters_pt
        joint = filters_in + filters_pt
        self.process = block(joint, joint, kernel_size=3, bias=False, activation=True, norm=norm)
        self.decimate = block(joint, filters_out, kernel_size=1, bias=False, activation=False, norm=norm)

    def forward(self, input, passthrough):
        joined = torch.cat([input, passthrough], dim=1)
        return self.decimate(self.process(joined))
|
|
|
|
|
|
# Designed explicitly to join a mainline trunk with reference data. Implemented as a residual branch.
|
|
class ReferenceJoinBlock(nn.Module):
    """Residually merge a mainline trunk `x` with reference features `ref`.

    `x` and `ref` are concatenated on the channel axis and passed through a
    MultiConvBlock(nf * 2, nf + nf // 2, nf, ...) branch. The branch output is
    added onto `x`. When `join` is true, a final conv block post-processes
    that sum.

    Args:
        nf: channel count of `x` (and, presumably, of `ref`, since the
            branch expects nf * 2 inputs; confirm with callers).
        residual_weight_init_factor: passed to the branch as both
            `scale_init` and `weight_init_factor`.
        block: conv block factory used for the optional join conv.
        final_norm: `norm` flag for the join conv.
        kernel_size: kernel size for the branch and the join conv.
        depth: depth of the MultiConvBlock branch.
        join: whether to build the trailing join conv.

    forward() returns a tuple `(output, torch.std(branch))`. The second item
    is the std of the residual branch, presumably for logging/diagnostics.
    """

    def __init__(self, nf, residual_weight_init_factor=1, block=ConvGnLelu, final_norm=False, kernel_size=3, depth=3, join=True):
        super().__init__()
        self.branch = MultiConvBlock(
            nf * 2, nf + nf // 2, nf,
            kernel_size=kernel_size, depth=depth,
            scale_init=residual_weight_init_factor, norm=False,
            weight_init_factor=residual_weight_init_factor,
        )
        self.join_conv = (
            block(nf, nf, kernel_size=kernel_size, norm=final_norm, bias=False, activation=True)
            if join else None
        )

    def forward(self, x, ref):
        branch = self.branch(torch.cat([x, ref], dim=1))
        merged = x + branch
        if self.join_conv is not None:
            merged = self.join_conv(merged)
        return merged, torch.std(branch)
|
|
|
|
|
|
# Basic convolutional upsampling block that uses interpolate.
|
|
class UpconvBlock(nn.Module):
    """Nearest-neighbour 2x upsample followed by a single 3x3 conv block.

    Args:
        filters_in: channels of the incoming tensor.
        filters_out: channels produced by the conv. When None, the channel
            count is kept unchanged (defaults to `filters_in`).
        block: conv block factory, called as
            block(in, out, kernel_size=3, bias=..., activation=..., norm=...).
        norm, activation, bias: forwarded to `block`.
    """

    def __init__(self, filters_in, filters_out=None, block=ConvGnSilu, norm=True, activation=True, bias=False):
        super(UpconvBlock, self).__init__()
        if filters_out is None:
            # The signature advertises None as the default, so resolve it here
            # instead of forwarding None into the conv factory. Keeping the
            # channel count mirrors ConjoinBlock's default.
            filters_out = filters_in
        self.process = block(filters_in, filters_out, kernel_size=3, bias=bias, activation=activation, norm=norm)

    def forward(self, x):
        x = F.interpolate(x, scale_factor=2, mode="nearest")
        return self.process(x)
|
|
|
|
|
|
# Scales an image up 2x and performs intermediary processing. Designed to be the final block in an SR network.
|
|
class FinalUpsampleBlock2x(nn.Module):
    """Final SR head: upsample feature maps and project them to `out_nc` channels.

    Despite the class name, the total upsampling depends on `scale`:
    `scale == 2` builds one 2x upsampling stage. Any other value builds two
    stages, 4x overall.

    Args:
        nf: number of incoming feature channels.
        block: conv block factory, called as
            block(in, out, kernel_size=..., norm=..., activation=..., bias=...).
        out_nc: number of output channels.
        scale: 2 for a 2x head; any other value yields a 4x head.
    """

    def __init__(self, nf, block=ConvGnLelu, out_nc=3, scale=2):
        super().__init__()
        half = nf // 2
        layers = [block(nf, nf, kernel_size=3, norm=False, activation=True, bias=True)]
        if scale != 2:
            # Extra 2x stage at the full channel width (4x overall).
            layers += [
                UpconvBlock(nf, nf, block=block, norm=False, activation=True, bias=True),
                block(nf, nf, kernel_size=3, norm=False, activation=False, bias=True),
            ]
        # Last 2x stage halves the width, then projects to out_nc.
        layers += [
            UpconvBlock(nf, half, block=block, norm=False, activation=True, bias=True),
            block(half, half, kernel_size=3, norm=False, activation=False, bias=True),
            block(half, out_nc, kernel_size=3, norm=False, activation=False, bias=False),
        ]
        # Same layer order as before, so nn.Sequential indices (state_dict keys)
        # are unchanged.
        self.chain = nn.Sequential(*layers)

    def forward(self, x):
        return self.chain(x)
|
|
|
|
# Convenience wrapper around torch.gather(): pulls the values of `input` at one 2D (row, col) position per batch element.
|
|
def gather_2d(input, index):
    """Gather the channel vector at one spatial location per batch element.

    Unlike torch.gather(), which needs an index tensor laid out like its
    output, this takes plain (row, col) coordinates.

    Args:
        input: tensor of shape (b, c, h, w).
        index: integer tensor of shape (b, 2) (only the first two columns are
            read). index[i, 0] is the row (along h) and index[i, 1] the
            column (along w) to read for batch element i. torch.gather
            requires an int64 (long) index.

    Returns:
        Tensor of shape (b, c) where
        result[i] == input[i, :, index[i, 0], index[i, 1]].
    """
    b, c, h, w = input.shape
    # reshape (not view) so non-contiguous inputs, e.g. permuted ones, also work.
    flat = input.reshape(b, c, h * w)
    flat_index = index[:, 0] * w + index[:, 1]                        # (b,)
    flat_index = flat_index.unsqueeze(1).repeat(1, c).unsqueeze(2)   # (b, c, 1)
    # Squeeze only the gathered axis: a bare squeeze() would also drop the
    # batch and/or channel axis whenever b == 1 or c == 1.
    return torch.gather(flat, dim=2, index=flat_index).squeeze(2)
|