2021-10-18 04:51:17 +00:00
|
|
|
import math
|
|
|
|
from abc import abstractmethod
|
|
|
|
|
2019-08-23 13:42:47 +00:00
|
|
|
import torch
|
|
|
|
import torch.nn as nn
|
|
|
|
import torch.nn.functional as F
|
2023-03-21 15:39:28 +00:00
|
|
|
import torch.nn.init as init
|
2019-08-23 13:42:47 +00:00
|
|
|
|
2023-03-21 15:39:28 +00:00
|
|
|
import dlas.torch_intermediary as ml
|
|
|
|
from dlas.utils.util import checkpoint
|
2022-01-22 15:22:57 +00:00
|
|
|
|
2021-10-21 03:19:25 +00:00
|
|
|
|
|
|
|
def exists(val):
    """Return True when ``val`` is anything other than ``None``."""
    return not (val is None)
|
|
|
|
|
|
|
|
|
|
|
|
def default(val, d):
    """Return ``val`` unless it is ``None``, in which case return the fallback ``d``."""
    if val is None:
        return d
    return val
|
|
|
|
|
|
|
|
|
|
|
|
def l2norm(t):
    """Rescale ``t`` to unit Euclidean (L2) norm along its last dimension."""
    return F.normalize(t, dim=-1, p=2)
|
2021-10-21 03:19:25 +00:00
|
|
|
|
|
|
|
|
|
|
|
def ema_inplace(moving_avg, new, decay):
    """Update ``moving_avg`` in place: ``moving_avg = decay * moving_avg + (1 - decay) * new``.

    Operates on ``moving_avg.data`` so the update is not recorded by autograd.
    """
    weight_new = 1 - decay
    target = moving_avg.data
    target.mul_(decay)
    target.add_(new, alpha=weight_new)
|
2021-10-21 03:19:25 +00:00
|
|
|
|
|
|
|
|
2023-03-21 15:39:28 +00:00
|
|
|
def laplace_smoothing(x, n_categories, eps=1e-5):
    """Additively smooth the counts in ``x`` and normalize them.

    Adds ``eps`` to every entry and divides by the total count plus
    ``n_categories * eps``.
    """
    numerator = x + eps
    denominator = x.sum() + n_categories * eps
    return numerator / denominator
|
|
|
|
|
|
|
|
|
|
|
|
def sample_vectors(samples, num):
    """Pick ``num`` rows (along dim 0) of ``samples`` at random.

    The rows are drawn without replacement (``randperm``) when ``samples`` has at
    least ``num`` rows, and with replacement (``randint``) otherwise.
    """
    available = samples.shape[0]
    dev = samples.device
    if available < num:
        # Not enough rows to draw distinct ones: sample indices with replacement.
        picks = torch.randint(0, available, (num,), device=dev)
    else:
        # Random distinct subset of size num.
        picks = torch.randperm(available, device=dev)[:num]
    return samples[picks]
|
|
|
|
|
|
|
|
|
2020-10-27 16:25:31 +00:00
|
|
|
def kaiming_init(module,
                 a=0,
                 mode='fan_out',
                 nonlinearity='relu',
                 bias=0,
                 distribution='normal'):
    """Kaiming-initialize ``module.weight`` and fill ``module.bias`` with a constant.

    Args:
        module (nn.Module): layer exposing a ``weight`` tensor (and optionally ``bias``).
        a, mode, nonlinearity: forwarded to ``nn.init.kaiming_{normal,uniform}_``.
        bias (float): constant written to ``module.bias`` when it exists and is not None.
        distribution (str): ``'normal'`` or ``'uniform'``; anything else fails the assertion.
    """
    assert distribution in ['uniform', 'normal']
    initializer = nn.init.kaiming_uniform_ if distribution == 'uniform' else nn.init.kaiming_normal_
    initializer(module.weight, a=a, mode=mode, nonlinearity=nonlinearity)
    if getattr(module, 'bias', None) is not None:
        nn.init.constant_(module.bias, bias)
|
|
|
|
|
2021-10-21 03:19:25 +00:00
|
|
|
|
2020-04-29 21:17:43 +00:00
|
|
|
def pixel_norm(x, epsilon=1e-8):
    """Normalize ``x`` so the mean of squares across dim 1 is ~1 at every other index.

    ``epsilon`` is added inside the reciprocal square root for numerical safety.
    """
    mean_sq = x.pow(2).mean(dim=1, keepdim=True)
    return x * torch.rsqrt(mean_sq + epsilon)
|
2019-08-23 13:42:47 +00:00
|
|
|
|
2021-10-21 03:19:25 +00:00
|
|
|
|
2019-08-23 13:42:47 +00:00
|
|
|
def initialize_weights(net_l, scale=1):
    """Initialize Conv/Linear/BatchNorm2d layers of one or more networks in place.

    Conv2d/Conv3d and ``ml.Linear`` weights get Kaiming-normal (fan_in) init,
    multiplied by ``scale``, and their biases are zeroed. BatchNorm2d layers get
    weight=1, bias=0; layers built with ``affine=False`` (weight/bias are None)
    are skipped instead of crashing.

    Args:
        net_l (nn.Module | list[nn.Module]): a network or a list of networks.
        scale (float): multiplier applied to Conv/Linear weights after init
            (for residual blocks).
    """
    nets = net_l if isinstance(net_l, list) else [net_l]
    for net in nets:
        for m in net.modules():
            if isinstance(m, (nn.Conv2d, nn.Conv3d)):
                init.kaiming_normal_(m.weight, a=0, mode='fan_in')
                m.weight.data *= scale  # for residual block
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                # BatchNorm2d(affine=False) has weight/bias == None.
                if m.weight is not None:
                    init.constant_(m.weight, 1)
                if m.bias is not None:
                    init.constant_(m.bias.data, 0.0)
            elif isinstance(m, ml.Linear):
                init.kaiming_normal_(m.weight, a=0, mode='fan_in')
                m.weight.data *= scale
                if m.bias is not None:
                    m.bias.data.zero_()
|
|
|
|
|
|
|
|
|
2020-10-27 16:25:31 +00:00
|
|
|
def make_layer(block, num_blocks, **kwarg):
    """Stack ``num_blocks`` freshly constructed instances of ``block``.

    Args:
        block (type[nn.Module]): module class to instantiate.
        num_blocks (int): how many instances to create.
        **kwarg: keyword arguments passed to every ``block(...)`` call.

    Returns:
        nn.Sequential: the instances, in construction order.
    """
    return nn.Sequential(*(block(**kwarg) for _ in range(num_blocks)))
|
|
|
|
|
|
|
|
|
|
|
|
def default_init_weights(module, scale=1):
    """Initialize network weights.

    Every Conv2d and ``ml.Linear`` found by ``module.modules()`` (which includes
    ``module`` itself) gets Kaiming-normal (fan_in) weights via ``kaiming_init``,
    a zero bias, and its weights multiplied by ``scale``.

    Args:
        module (nn.Module): Module whose submodules are initialized.
        scale (float): Scale initialized weights, especially for residual
            blocks.
    """
    for m in module.modules():
        if isinstance(m, (nn.Conv2d, ml.Linear)):
            kaiming_init(m, a=0, mode='fan_in', bias=0)
            m.weight.data *= scale
|
2019-08-23 13:42:47 +00:00
|
|
|
|
2020-04-28 17:48:05 +00:00
|
|
|
|
2021-10-18 04:51:17 +00:00
|
|
|
# Hand-rolled SiLU: PyTorch 1.7 added nn.SiLU, but this code targets PyTorch 1.5 too.
# NOTE(review): another `SiLU` class is defined later in this module and replaces
# this one once the module finishes importing.
class SiLU(nn.Module):
    """Sigmoid-weighted linear unit: ``x * sigmoid(x)``."""

    def forward(self, x):
        gate = torch.sigmoid(x)
        return x * gate
|
2020-04-28 17:48:05 +00:00
|
|
|
|
|
|
|
|
2021-10-18 04:51:17 +00:00
|
|
|
class GroupNorm32(nn.GroupNorm):
    """GroupNorm computed in float32 whose result is cast back to the input's dtype."""

    def forward(self, x):
        original_dtype = x.dtype
        normed = super().forward(x.float())
        return normed.type(original_dtype)
|
|
|
|
|
|
|
|
|
|
|
|
def conv_nd(dims, *args, **kwargs):
    """
    Build an nn.Conv1d / Conv2d / Conv3d selected by ``dims``.

    Remaining positional and keyword arguments go to the chosen class's constructor.
    Raises ValueError for any other ``dims``.
    """
    for ndim, conv_cls in ((1, nn.Conv1d), (2, nn.Conv2d), (3, nn.Conv3d)):
        if dims == ndim:
            return conv_cls(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")
|
|
|
|
|
|
|
|
|
|
|
|
def linear(*args, **kwargs):
    """
    Build a linear layer using ``ml.Linear``; every argument is forwarded unchanged.
    """
    layer_cls = ml.Linear
    return layer_cls(*args, **kwargs)
|
2021-10-18 04:51:17 +00:00
|
|
|
|
2020-04-28 17:48:05 +00:00
|
|
|
|
2021-10-18 04:51:17 +00:00
|
|
|
def avg_pool_nd(dims, *args, **kwargs):
    """
    Build an nn.AvgPool1d / AvgPool2d / AvgPool3d selected by ``dims``.

    Remaining arguments go to the chosen class's constructor. Raises ValueError
    for any other ``dims``.
    """
    for ndim, pool_cls in ((1, nn.AvgPool1d), (2, nn.AvgPool2d), (3, nn.AvgPool3d)):
        if dims == ndim:
            return pool_cls(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")
|
|
|
|
|
|
|
|
|
|
|
|
def update_ema(target_params, source_params, rate=0.99):
    """
    Move each target parameter toward its source counterpart with an EMA step:
    ``target = rate * target + (1 - rate) * source`` (in place).

    :param target_params: the target parameter sequence (modified in place).
    :param source_params: the source parameter sequence, paired with targets by position.
    :param rate: the EMA rate (closer to 1 means slower).
    """
    blend = 1 - rate
    for dst, src in zip(target_params, source_params):
        # Write through a detached alias so autograd does not track the in-place update.
        dst.detach().mul_(rate).add_(src, alpha=blend)
|
2020-04-29 05:00:29 +00:00
|
|
|
|
|
|
|
|
2021-10-18 04:51:17 +00:00
|
|
|
def zero_module(module):
    """
    Set every parameter of ``module`` to zero in place; return the same module object.
    """
    params = list(module.parameters())
    for tensor in params:
        tensor.detach().zero_()
    return module
|
|
|
|
|
|
|
|
|
|
|
|
def scale_module(module, scale):
    """
    Multiply every parameter of ``module`` by ``scale`` in place; return the same module.
    """
    params = list(module.parameters())
    for tensor in params:
        tensor.detach().mul_(scale)
    return module
|
|
|
|
|
|
|
|
|
|
|
|
def mean_flat(tensor):
    """
    Average ``tensor`` over every dimension except dim 0 (the batch dimension).
    """
    reduce_dims = list(range(1, tensor.dim()))
    return tensor.mean(dim=reduce_dims)
|
|
|
|
|
|
|
|
|
|
|
|
def normalization(channels):
    """
    Make a standard normalization layer (GroupNorm32).

    The group count starts at 8 (<=16 channels), 16 (<=64 channels) or 32, and is
    halved until it divides ``channels``; the resulting count must exceed 2.

    :param channels: number of input channels.
    :return: an nn.Module for normalization.
    """
    if channels <= 16:
        groups = 8
    elif channels <= 64:
        groups = 16
    else:
        groups = 32
    while channels % groups != 0:
        groups //= 2
    assert groups > 2
    return GroupNorm32(groups, channels)
|
|
|
|
|
|
|
|
|
|
|
|
class AttentionPool2d(nn.Module):
    """
    Attention-based global pooling: a mean token is prepended to the flattened
    spatial tokens, and the attention output at that token is returned.

    Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
    """

    def __init__(
        self,
        spacial_dim: int,
        embed_dim: int,
        num_heads_channels: int,
        output_dim: int = None,
    ):
        super().__init__()
        token_count = spacial_dim ** 2 + 1
        self.positional_embedding = nn.Parameter(
            torch.randn(embed_dim, token_count) / embed_dim ** 0.5
        )
        self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
        self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
        self.num_heads = embed_dim // num_heads_channels
        self.attention = QKVAttention(self.num_heads)

    def forward(self, x):
        batch, channels, *_ = x.shape
        tokens = x.reshape(batch, channels, -1)  # N x C x (HW)
        summary = tokens.mean(dim=-1, keepdim=True)
        tokens = torch.cat([summary, tokens], dim=-1)  # N x C x (HW+1)
        pos = self.positional_embedding[None, :, :tokens.shape[-1]].to(tokens.dtype)
        tokens = tokens + pos
        attended = self.attention(self.qkv_proj(tokens))
        projected = self.c_proj(attended)
        return projected[:, :, 0]
|
|
|
|
|
|
|
|
|
|
|
|
class TimestepBlock(nn.Module):
    """
    Marker base class for modules whose forward() takes a timestep embedding
    as its second argument.

    NOTE: nn.Module does not use ABCMeta, so ``@abstractmethod`` below is not
    enforced; the class can still be instantiated.
    """

    @abstractmethod
    def forward(self, x, emb):
        """
        Process ``x`` using the timestep embeddings ``emb``; subclasses implement this.
        """
|
|
|
|
|
|
|
|
|
|
|
|
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    """
    nn.Sequential variant that also feeds the timestep embedding ``emb`` to each
    child that is a TimestepBlock; other children receive only the activation.
    """

    def forward(self, x, emb):
        h = x
        for module in self:
            h = module(h, emb) if isinstance(module, TimestepBlock) else module(h)
        return h
|
|
|
|
|
|
|
|
|
|
|
|
class Upsample(nn.Module):
    """
    An upsampling layer with an optional convolution.

    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
                 upsampling occurs in the inner-two dimensions.
    :param out_channels: output channels of the optional convolution
                         (defaults to ``channels``).
    :param factor: nearest-neighbour upsampling factor.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None, factor=2):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        self.factor = factor
        if use_conv:
            ksize = 3
            pad = 1
            if dims == 1:
                ksize = 5
                pad = 2
            self.conv = conv_nd(dims, self.channels,
                                self.out_channels, ksize, padding=pad)

    def forward(self, x):
        assert x.shape[1] == self.channels
        if self.dims == 3:
            # Leave the first spatial axis untouched; scale only the inner two.
            x = F.interpolate(
                x, (x.shape[2], x.shape[3] * self.factor, x.shape[4] * self.factor), mode="nearest"
            )
        else:
            x = F.interpolate(x, scale_factor=self.factor, mode="nearest")
        if self.use_conv:
            x = self.conv(x)
        return x
|
2019-08-23 13:42:47 +00:00
|
|
|
|
|
|
|
|
2021-10-18 04:51:17 +00:00
|
|
|
class Downsample(nn.Module):
    """
    A downsampling layer with an optional convolution.

    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
                 downsampling occurs in the inner-two dimensions.
    :param out_channels: output channels of the optional convolution
                         (defaults to ``channels``; must equal it without a conv).
    :param factor: downsampling factor (stride of the conv / pooling window).
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None, factor=2):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        ksize = 3
        pad = 1
        # For 3D inputs keep the first spatial axis and stride only the inner two,
        # as the docstring specifies.
        stride = factor if dims != 3 else (1, factor, factor)
        if use_conv:
            self.op = conv_nd(
                dims, self.channels, self.out_channels, ksize, stride=stride, padding=pad
            )
        else:
            assert self.channels == self.out_channels
            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)

    def forward(self, x):
        assert x.shape[1] == self.channels
        return self.op(x)
|
|
|
|
|
|
|
|
|
2022-07-21 05:28:29 +00:00
|
|
|
class cGLU(nn.Module):
    """
    Gated GELU for channel-first architectures.

    A 1x1 Conv1d projects to ``2 * dim_out`` channels; the first half is gated
    by GELU of the second half.
    """

    def __init__(self, dim_in, dim_out=None):
        super().__init__()
        if dim_out is None:
            dim_out = dim_in
        self.proj = nn.Conv1d(dim_in, dim_out * 2, 1)

    def forward(self, x):
        projected = self.proj(x)
        value, gate = projected.chunk(2, dim=1)
        return value * F.gelu(gate)
|
|
|
|
|
|
|
|
|
2021-10-18 04:51:17 +00:00
|
|
|
class ResBlock(nn.Module):
    """
    A residual block that can optionally change the number of channels.

    :param channels: the number of input channels.
    :param dropout: the rate of dropout.
    :param out_channels: if specified, the number of out channels.
    :param use_conv: if True and out_channels is specified, use a spatial
        convolution instead of a smaller 1x1 convolution to change the
        channels in the skip connection. It is also passed to the
        Upsample/Downsample layers built when ``up``/``down`` is set.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param up: if True, use this block for upsampling.
    :param down: if True, use this block for downsampling.
    :param kernel_size: kernel size of the main convolutions; padding is
        ``kernel_size // 2`` so odd sizes preserve the spatial size.
    :param checkpointing_enabled: if True, forward() runs ``_forward`` through
        ``checkpoint`` (from dlas.utils.util).
    """

    def __init__(
        self,
        channels,
        dropout=0,
        out_channels=None,
        use_conv=False,
        dims=2,
        up=False,
        down=False,
        kernel_size=3,
        checkpointing_enabled=True,
    ):
        super().__init__()
        self.channels = channels
        self.dropout = dropout
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.checkpointing_enabled = checkpointing_enabled
        # "Same" padding for odd kernels (1 for k=3, 2 for k=5, 0 for k=1, ...), so the
        # residual add in _forward sees matching shapes.
        padding = kernel_size // 2

        self.in_layers = nn.Sequential(
            normalization(channels),
            nn.SiLU(),
            conv_nd(dims, channels, self.out_channels,
                    kernel_size, padding=padding),
        )

        self.updown = up or down

        if up:
            self.h_upd = Upsample(channels, use_conv, dims)
            self.x_upd = Upsample(channels, use_conv, dims)
        elif down:
            self.h_upd = Downsample(channels, use_conv, dims)
            self.x_upd = Downsample(channels, use_conv, dims)
        else:
            self.h_upd = self.x_upd = nn.Identity()

        self.out_layers = nn.Sequential(
            normalization(self.out_channels),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            zero_module(
                conv_nd(dims, self.out_channels, self.out_channels,
                        kernel_size, padding=padding)
            ),
        )

        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        elif use_conv:
            self.skip_connection = conv_nd(
                dims, channels, self.out_channels, kernel_size, padding=padding
            )
        else:
            self.skip_connection = conv_nd(
                dims, channels, self.out_channels, 1)

    def forward(self, x):
        """
        Apply the block to a Tensor.

        :param x: an [N x C x ...] Tensor of features.
        :return: an [N x C x ...] Tensor of outputs.
        """
        if self.checkpointing_enabled:
            return checkpoint(
                self._forward, x
            )
        else:
            return self._forward(x)

    def _forward(self, x):
        if self.updown:
            # Resample after norm+activation but before the first conv; the skip
            # path input is resampled the same way so the shapes match.
            in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
            h = in_rest(x)
            h = self.h_upd(h)
            x = self.x_upd(x)
            h = in_conv(h)
        else:
            h = self.in_layers(x)
        h = self.out_layers(h)
        return self.skip_connection(x) + h
|
|
|
|
|
|
|
|
|
2022-07-21 05:28:29 +00:00
|
|
|
def build_local_attention_mask(n, l, fixed_region=0):
    """
    Builds a boolean attention mask that focuses attention on a local region.
    Includes provisions for a "fixed_region" at the start of the sequence where full attention weights will be applied.
    Args:
        n: Size of returned matrix (maximum sequence size)
        l: Size of local context (uni-directional, e.g. the total context is l*2).
           Position i may attend to position j when |i - j| < l.
        fixed_region: The number of sequence elements at the start of the sequence that get full attention
                      (their rows and columns are entirely True).
    Returns:
        An [n x n] bool mask that can be applied to AttentionBlock to achieve local attention.
    """
    assert l*2 < n, f'Local context must be less than global context. {l}, {n}'
    positions = torch.arange(0, n)
    distance = (positions.unsqueeze(-1) - positions.unsqueeze(0)).abs()
    # Compare distances directly instead of clamp-and-divide by (l - 1): this keeps
    # l == 1 well defined (diagonal only) rather than producing 0/0 and an empty mask.
    mask = distance < l
    mask[:fixed_region] = True
    mask[:, :fixed_region] = True
    return mask
|
|
|
|
|
|
|
|
|
|
|
|
def test_local_attention_mask():
    """Print an example mask (n=9, l=4, fixed_region=1) for visual inspection."""
    example = build_local_attention_mask(9, 4, 1)
    print(example)
|
2022-07-19 20:59:43 +00:00
|
|
|
|
|
|
|
|
2022-07-21 05:28:29 +00:00
|
|
|
class RelativeQKBias(nn.Module):
    """
    Very simple relative position bias scheme which should be directly added to QK matrix. This bias simply applies to
    the distance from the given element.

    If symmetric=False, a different bias is applied to each side of the input element, otherwise the bias is symmetric.
    """

    def __init__(self, l, max_positions=4000, symmetric=True):
        super().__init__()
        idx = torch.arange(0, max_positions)
        if symmetric:
            # One learned scalar per distance; index l is distance 0 and every
            # distance >= l shares index 0.
            self.emb = nn.Parameter(torch.randn(l+1) * .01)
            dist = (idx.unsqueeze(-1) - idx.unsqueeze(0)).abs()
            M = (l - dist).clamp(0, l)
        else:
            # Separate scalars for signed offsets i - j in [-l, l] (indices 1..2l+1);
            # offsets outside that band map to index 0.
            self.emb = nn.Parameter(torch.randn(l*2+2) * .01)
            offset = idx.unsqueeze(-1) - idx.unsqueeze(0)
            in_band = (offset >= -l).logical_and(offset <= l)
            M = (offset + l + 1) * in_band
        self.register_buffer('M', M, persistent=False)

    def forward(self, n):
        # Equivalent to self.emb[self.M[:n, :n]].view(1, n, n), but advanced indexing is
        # slow on GPUs (https://github.com/pytorch/pytorch/issues/15245), so gather from
        # a column-repeated copy of the embedding table instead.
        index = self.M[:n, :n]
        table = self.emb.unsqueeze(-1).repeat(1, n)
        return torch.gather(table, 0, index).view(1, n, n)
|
2022-07-21 05:28:29 +00:00
|
|
|
|
|
|
|
|
2021-10-18 04:51:17 +00:00
|
|
|
class AttentionBlock(nn.Module):
    """
    An attention block that allows spatial positions to attend to each other.

    Originally ported from here, but adapted to the N-d case.
    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.

    :param channels: number of input channels.
    :param num_heads: number of attention heads (used when num_head_channels == -1).
    :param num_head_channels: if not -1, channels per head; the head count becomes
        channels // num_head_channels.
    :param out_channels: output channels (defaults to ``channels``); when different,
        the residual path goes through a 1x1 conv projection.
    :param use_new_attention_order: use QKVAttention (split qkv, then heads) instead of
        QKVAttentionLegacy (split heads, then qkv).
    :param do_checkpoint: run ``_forward`` through ``checkpoint`` (from dlas.utils.util).
    :param do_activation: apply SiLU after the normalization.
    """

    def __init__(
        self,
        channels,
        num_heads=1,
        num_head_channels=-1,
        out_channels=None,
        use_new_attention_order=False,
        do_checkpoint=True,
        do_activation=False,
    ):
        super().__init__()
        self.channels = channels
        out_channels = channels if out_channels is None else out_channels
        self.do_checkpoint = do_checkpoint
        self.do_activation = do_activation
        if num_head_channels == -1:
            self.num_heads = num_heads
        else:
            assert (
                channels % num_head_channels == 0
            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
            self.num_heads = channels // num_head_channels
        self.norm = normalization(channels)
        self.qkv = conv_nd(1, channels, out_channels * 3, 1)
        if use_new_attention_order:
            # split qkv before split heads
            self.attention = QKVAttention(self.num_heads)
        else:
            # split heads before split qkv
            self.attention = QKVAttentionLegacy(self.num_heads)

        self.x_proj = nn.Identity() if out_channels == channels else conv_nd(
            1, channels, out_channels, 1)
        self.proj_out = zero_module(conv_nd(1, out_channels, out_channels, 1))

    def forward(self, x, mask=None, qk_bias=None):
        """
        :param x: an [N x C x ...] tensor.
        :param mask: optional boolean mask, [T x T] or [N x T x T] with T the flattened
            spatial length; False entries are excluded from attention.
        :param qk_bias: optional additive bias on the attention logits.
        """
        if self.do_checkpoint:
            if mask is None:
                if qk_bias is not None:
                    raise AssertionError('unsupported: qk_bias but no mask')
                return checkpoint(self._forward, x)
            if qk_bias is None:
                return checkpoint(self._forward, x, mask)
            return checkpoint(self._forward, x, mask, qk_bias)
        # qk_bias must be forwarded here too; _forward expects 0 (not None) when absent.
        return self._forward(x, mask, 0 if qk_bias is None else qk_bias)

    def _forward(self, x, mask=None, qk_bias=0):
        b, c, *spatial = x.shape
        x = x.reshape(b, c, -1)
        # Sequence length after flattening all spatial dims (equals x.shape[-1] for 1D input).
        length = x.shape[-1]
        if mask is not None:
            if len(mask.shape) == 2:
                # Share a single [T x T] mask across the batch.
                mask = mask.unsqueeze(0).repeat(b, 1, 1)
            if mask.shape[1] != length:
                # Trim the mask to the actual sequence length.
                mask = mask[:, :length, :length]

        x = self.norm(x)
        if self.do_activation:
            x = F.silu(x, inplace=True)
        qkv = self.qkv(x)
        h = self.attention(qkv, mask, qk_bias)
        h = self.proj_out(h)
        # NOTE(review): the residual path uses the normalized (and optionally activated)
        # x rather than the raw block input -- confirm this is intentional.
        xp = self.x_proj(x)
        return (xp + h).reshape(b, xp.shape[1], *spatial)
|
2021-10-18 04:51:17 +00:00
|
|
|
|
|
|
|
|
|
|
|
class QKVAttentionLegacy(nn.Module):
    """
    A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv, mask=None, qk_bias=0):
        """
        Apply QKV attention.

        :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
        :param mask: optional boolean mask, [N x T x T] (per batch element) or
            [T x T] (shared); positions where it is False are excluded.
        :param qk_bias: additive bias broadcast onto the [N*H x T x T] logits.
        :return: an [N x (H * C) x T] tensor after attention.
        """
        bs, width, length = qkv.shape
        assert width % (3 * self.n_heads) == 0
        ch = width // (3 * self.n_heads)
        q, k, v = qkv.reshape(bs * self.n_heads, ch * 3,
                              length).split(ch, dim=1)
        scale = 1 / math.sqrt(math.sqrt(ch))
        weight = torch.einsum(
            "bct,bcs->bts", q * scale, k * scale
        )  # More stable with f16 than dividing afterwards
        weight = weight + qk_bias
        if mask is not None:
            if mask.dim() == 3:
                # Logit rows are batch-major (row = b * n_heads + h), so each batch
                # element's mask must be repeated n_heads times consecutively.
                mask = mask.repeat_interleave(self.n_heads, dim=0)
            # A 2D mask broadcasts over every batch/head row.
            weight = weight.masked_fill(mask.logical_not(), -torch.inf)
        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
        a = torch.einsum("bts,bcs->bct", weight, v)
        return a.reshape(bs, -1, length)
|
|
|
|
|
|
|
|
|
|
|
|
class QKVAttention(nn.Module):
    """
    A module which performs QKV attention and splits in a different order.
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv, mask=None, qk_bias=0):
        """
        Apply QKV attention.

        :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
        :param mask: optional boolean mask, [N x T x T] (per batch element) or
            [T x T] (shared); positions where it is False are excluded.
        :param qk_bias: additive bias broadcast onto the [N*H x T x T] logits
            (applied exactly as in QKVAttentionLegacy).
        :return: an [N x (H * C) x T] tensor after attention.
        """
        bs, width, length = qkv.shape
        assert width % (3 * self.n_heads) == 0
        ch = width // (3 * self.n_heads)
        q, k, v = qkv.chunk(3, dim=1)
        scale = 1 / math.sqrt(math.sqrt(ch))
        weight = torch.einsum(
            "bct,bcs->bts",
            (q * scale).view(bs * self.n_heads, ch, length),
            (k * scale).view(bs * self.n_heads, ch, length),
        )  # More stable with f16 than dividing afterwards
        weight = weight + qk_bias
        if mask is not None:
            if mask.dim() == 3:
                # Logit rows are batch-major (row = b * n_heads + h), so each batch
                # element's mask must be repeated n_heads times consecutively.
                mask = mask.repeat_interleave(self.n_heads, dim=0)
            # A 2D mask broadcasts over every batch/head row.
            weight = weight.masked_fill(mask.logical_not(), -torch.inf)
        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
        a = torch.einsum("bts,bcs->bct", weight,
                         v.reshape(bs * self.n_heads, ch, length))
        return a.reshape(bs, -1, length)
|
2020-11-11 20:56:45 +00:00
|
|
|
|
|
|
|
|
2019-08-23 13:42:47 +00:00
|
|
|
def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros', align_corners=True):
    """Warp an image or feature map with optical flow

    Args:
        x (Tensor): size (N, C, H, W)
        flow (Tensor): size (N, H, W, 2), per-pixel (dx, dy) offsets in pixels
        interp_mode (str): 'nearest' or 'bilinear'
        padding_mode (str): 'zeros' or 'border' or 'reflection'
        align_corners (bool): forwarded to F.grid_sample. The grid below is
            normalized with the (size - 1) convention, which matches
            align_corners=True (zero flow returns x unchanged). Pass False to
            reproduce the previous behaviour, which relied on grid_sample's
            default of False.

    Returns:
        Tensor: warped image or feature map
    """
    assert x.size()[-2:] == flow.size()[1:3]
    _, _, h, w = x.size()
    # Base pixel-coordinate grid, (H, W, 2) holding (x, y), built on x's device.
    grid_y = torch.arange(h, device=x.device, dtype=torch.float32).view(h, 1).expand(h, w)
    grid_x = torch.arange(w, device=x.device, dtype=torch.float32).view(1, w).expand(h, w)
    grid = torch.stack((grid_x, grid_y), 2).type_as(x)
    vgrid = grid + flow
    # scale grid to [-1,1]
    vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(w - 1, 1) - 1.0
    vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(h - 1, 1) - 1.0
    vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3)
    return F.grid_sample(
        x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode, align_corners=align_corners)
|
2020-06-13 17:37:27 +00:00
|
|
|
|
|
|
|
|
|
|
|
class PixelUnshuffle(nn.Module):
    """Space-to-depth: fold each r x r patch of dims 2-3 into the channel dim.

    A 4D input (B, C, D2, D3) becomes (B, C * r * r, D2 // r, D3 // r); the new
    channel index is c * r * r + (row offset) * r + (column offset).
    """

    def __init__(self, reduction_factor):
        super(PixelUnshuffle, self).__init__()
        self.r = reduction_factor

    def forward(self, x):
        r = self.r
        batch, chans, size_a, size_b = x.shape
        out_a, out_b = size_a // r, size_b // r
        patches = x.contiguous().view(batch, chans, out_a, r, out_b, r)
        patches = patches.permute(0, 1, 3, 5, 2, 4).contiguous()
        return patches.view(batch, chans * (r ** 2), out_a, out_b)
|
|
|
|
|
2020-07-03 18:06:38 +00:00
|
|
|
|
2020-07-05 19:39:08 +00:00
|
|
|
# Functional SiLU, usable without a module wrapper.
def silu(input):
    '''
    Element-wise Sigmoid Linear Unit: SiLU(x) = x * sigmoid(x).
    '''
    sig = torch.sigmoid(input)
    return input * sig
|
2020-07-05 19:39:08 +00:00
|
|
|
|
|
|
|
# Module wrapper around the functional `silu`, so it can be used as a layer.
# NOTE: this definition replaces the earlier `SiLU` class in this module.
class SiLU(nn.Module):
    '''
    Applies the Sigmoid Linear Unit (SiLU) function element-wise:
        SiLU(x) = x * sigmoid(x)
    Shape:
        - Input: (N, *) where * means, any number of additional
          dimensions
        - Output: (N, *), same shape as the input
    References:
        - Related paper:
        https://arxiv.org/pdf/1606.08415.pdf
    Examples:
        >>> m = SiLU()
        >>> input = torch.randn(2)
        >>> output = m(input)
    '''

    def __init__(self):
        '''
        Init method (no parameters or buffers).
        '''
        super().__init__()  # init the base class

    def forward(self, input):
        '''
        Forward pass: delegate to the functional ``silu``.
        '''
        return silu(input)
|
2020-07-05 19:39:08 +00:00
|
|
|
|
|
|
|
|
2020-07-03 18:06:38 +00:00
|
|
|
class ConvBnRelu(nn.Module):
    """Conv2d -> optional BatchNorm2d -> optional ReLU.

    Padding is derived from ``kernel_size`` (1, 3, 5 or 7 only) so that stride-1
    convolutions keep the spatial size. Conv weights use Kaiming-normal (fan_out)
    init with the 'relu' gain when the activation is enabled, 'linear' otherwise;
    BatchNorm starts at weight=1, bias=0.
    """

    def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, activation=True, norm=True, bias=True):
        super(ConvBnRelu, self).__init__()
        padding_map = {1: 0, 3: 1, 5: 2, 7: 3}
        assert kernel_size in padding_map.keys()
        self.conv = nn.Conv2d(filters_in, filters_out, kernel_size,
                              stride, padding_map[kernel_size], bias=bias)
        self.bn = nn.BatchNorm2d(filters_out) if norm else None
        self.relu = nn.ReLU() if activation else None

        # Init params.
        gain_name = 'relu' if self.relu is not None else 'linear'
        nn.init.kaiming_normal_(self.conv.weight, mode='fan_out', nonlinearity=gain_name)
        if self.bn is not None:
            nn.init.constant_(self.bn.weight, 1)
            nn.init.constant_(self.bn.bias, 0)

    def forward(self, x):
        out = self.conv(x)
        if self.bn is not None:
            out = self.bn(out)
        if self.relu is None:
            return out
        return self.relu(out)
|
|
|
|
|
2020-07-05 19:39:08 +00:00
|
|
|
|
|
|
|
class ConvBnSilu(nn.Module):
    """Conv2d -> optional BatchNorm2d -> optional SiLU.

    Padding is derived from ``kernel_size`` (1, 3, 5 or 7 only) so that stride-1
    convolutions keep the spatial size. Conv weights use Kaiming-normal (fan_out)
    init ('relu' gain when the activation is enabled, 'linear' otherwise), are then
    multiplied by ``weight_init_factor``, and the conv bias (if any) is zeroed.
    BatchNorm starts at weight=1, bias=0.
    """

    def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, activation=True, norm=True, bias=True, weight_init_factor=1):
        super(ConvBnSilu, self).__init__()
        padding_map = {1: 0, 3: 1, 5: 2, 7: 3}
        assert kernel_size in padding_map.keys()
        self.conv = nn.Conv2d(filters_in, filters_out, kernel_size,
                              stride, padding_map[kernel_size], bias=bias)
        self.bn = nn.BatchNorm2d(filters_out) if norm else None
        self.silu = SiLU() if activation else None

        # Init params.
        gain_name = 'relu' if self.silu is not None else 'linear'
        nn.init.kaiming_normal_(self.conv.weight, mode='fan_out', nonlinearity=gain_name)
        self.conv.weight.data *= weight_init_factor
        if self.conv.bias is not None:
            self.conv.bias.data.zero_()
        if self.bn is not None:
            nn.init.constant_(self.bn.weight, 1)
            nn.init.constant_(self.bn.bias, 0)

    def forward(self, x):
        out = self.conv(x)
        if self.bn is not None:
            out = self.bn(out)
        if self.silu is None:
            return out
        return self.silu(out)
|
|
|
|
|
|
|
|
|
2020-07-03 18:06:38 +00:00
|
|
|
class ConvBnLelu(nn.Module):
    """Conv2d -> optional BatchNorm2d -> optional LeakyReLU(0.1), with auto-padding and weight init.

    Padding is looked up from ``kernel_size`` (1, 3, 5 or 7 map to 0, 1, 2, 3), so
    stride-1 convolutions keep the spatial size. Conv weights get Kaiming-normal
    init (``a=.1``, ``mode='fan_out'``, 'leaky_relu' gain when the activation is
    enabled, 'linear' otherwise) and are then multiplied by ``weight_init_factor``.
    Conv biases start at zero; norm layers start with weight 1 and bias 0.
    """

    def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, activation=True, norm=True, bias=True, weight_init_factor=1):
        super().__init__()
        same_padding = {1: 0, 3: 1, 5: 2, 7: 3}
        assert kernel_size in same_padding
        self.conv = nn.Conv2d(filters_in, filters_out, kernel_size,
                              stride, same_padding[kernel_size], bias=bias)
        self.bn = nn.BatchNorm2d(filters_out) if norm else None
        self.lelu = nn.LeakyReLU(negative_slope=.1) if activation else None

        gain = 'leaky_relu' if self.lelu else 'linear'
        for layer in self.modules():
            if isinstance(layer, nn.Conv2d):
                nn.init.kaiming_normal_(layer.weight, a=.1, mode='fan_out', nonlinearity=gain)
                layer.weight.data *= weight_init_factor
                if layer.bias is not None:
                    layer.bias.data.zero_()
            elif isinstance(layer, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(layer.weight, 1)
                nn.init.constant_(layer.bias, 0)

    def forward(self, x):
        h = self.conv(x)
        if self.bn:
            h = self.bn(h)
        if not self.lelu:
            return h
        return self.lelu(h)
|
|
|
|
|
|
|
|
|
|
|
|
class ConvGnLelu(nn.Module):
    """Conv2d -> optional GroupNorm -> optional LeakyReLU(0.2), with auto-padding and weight init.

    Args:
        filters_in: input channels.
        filters_out: output channels (must be divisible by ``num_groups`` when ``norm``).
        kernel_size: one of 1, 3, 5, 7; padding is chosen so stride-1 convs keep
            the spatial size.
        stride: conv stride.
        activation: append LeakyReLU when True.
        norm: insert GroupNorm(num_groups, filters_out) when True.
        bias: whether the conv has a bias (initialised to zero).
        num_groups: GroupNorm group count.
        weight_init_factor: multiplier applied to the Kaiming-initialised conv weight.
    """

    def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, activation=True, norm=True, bias=True, num_groups=8, weight_init_factor=1):
        super(ConvGnLelu, self).__init__()
        # Single source of truth for the activation slope: the Kaiming gain below
        # must be computed with the same slope the LeakyReLU actually uses.
        negative_slope = .2
        padding_map = {1: 0, 3: 1, 5: 2, 7: 3}
        assert kernel_size in padding_map.keys()
        self.conv = nn.Conv2d(filters_in, filters_out, kernel_size,
                              stride, padding_map[kernel_size], bias=bias)
        if norm:
            self.gn = nn.GroupNorm(num_groups, filters_out)
        else:
            self.gn = None
        if activation:
            self.lelu = nn.LeakyReLU(negative_slope=negative_slope)
        else:
            self.lelu = None

        # Init params.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, a=negative_slope, mode='fan_out',
                                        nonlinearity='leaky_relu' if self.lelu else 'linear')
                m.weight.data *= weight_init_factor
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.conv(x)
        if self.gn:
            x = self.gn(x)
        if self.lelu:
            return self.lelu(x)
        else:
            return x
|
|
|
|
|
2020-10-28 21:21:22 +00:00
|
|
|
|
2020-07-09 23:34:51 +00:00
|
|
|
class ConvGnSilu(nn.Module):
    """``convnd`` -> optional GroupNorm -> optional SiLU, with auto-padding and weight init.

    ``convnd`` defaults to ``nn.Conv2d`` but any conv class taking
    ``(in, out, kernel_size, stride, padding, bias=...)`` can be passed (e.g.
    ``nn.Conv1d``). Padding is looked up from ``kernel_size`` (1, 3, 5 or 7 map to
    0, 1, 2, 3). Conv weights get Kaiming-normal init (``mode='fan_out'``) and are
    then multiplied by ``weight_init_factor``. Conv biases start at zero; norm
    layers start with weight 1 and bias 0. The activation is the module-level
    ``SiLU`` class.
    """

    def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, activation=True, norm=True, bias=True, num_groups=8, weight_init_factor=1, convnd=nn.Conv2d):
        super().__init__()
        same_padding = {1: 0, 3: 1, 5: 2, 7: 3}
        assert kernel_size in same_padding
        self.conv = convnd(filters_in, filters_out, kernel_size,
                           stride, same_padding[kernel_size], bias=bias)
        self.gn = nn.GroupNorm(num_groups, filters_out) if norm else None
        self.silu = SiLU() if activation else None

        # There is no dedicated Kaiming gain for SiLU; the ReLU gain stands in for it.
        gain = 'relu' if self.silu else 'linear'
        for layer in self.modules():
            if isinstance(layer, convnd):
                nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity=gain)
                layer.weight.data *= weight_init_factor
                if layer.bias is not None:
                    layer.bias.data.zero_()
            elif isinstance(layer, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(layer.weight, 1)
                nn.init.constant_(layer.bias, 0)

    def forward(self, x):
        h = self.conv(x)
        if self.gn:
            h = self.gn(h)
        if not self.silu:
            return h
        return self.silu(h)
|
|
|
|
|
2020-09-08 14:17:27 +00:00
|
|
|
|
2020-12-16 00:16:19 +00:00
|
|
|
class ConvBnRelu(nn.Module):
    """Conv2d -> optional BatchNorm2d -> optional ReLU, with auto-padding and weight init.

    Padding is looked up from ``kernel_size`` (1, 3, 5 or 7 map to 0, 1, 2, 3), so
    stride-1 convolutions keep the spatial size. Conv weights get Kaiming-normal
    init (``mode='fan_out'``, 'relu' gain when the activation is enabled, 'linear'
    otherwise) and are then multiplied by ``weight_init_factor``. Conv biases start
    at zero; norm layers start with weight 1 and bias 0.
    """

    def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, activation=True, norm=True, bias=True, weight_init_factor=1):
        super().__init__()
        same_padding = {1: 0, 3: 1, 5: 2, 7: 3}
        assert kernel_size in same_padding
        self.conv = nn.Conv2d(filters_in, filters_out, kernel_size,
                              stride, same_padding[kernel_size], bias=bias)
        self.bn = nn.BatchNorm2d(filters_out) if norm else None
        self.relu = nn.ReLU() if activation else None

        gain = 'relu' if self.relu else 'linear'
        for layer in self.modules():
            if isinstance(layer, nn.Conv2d):
                nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity=gain)
                layer.weight.data *= weight_init_factor
                if layer.bias is not None:
                    layer.bias.data.zero_()
            elif isinstance(layer, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(layer.weight, 1)
                nn.init.constant_(layer.bias, 0)

    def forward(self, x):
        h = self.conv(x)
        if self.bn:
            h = self.bn(h)
        if not self.relu:
            return h
        return self.relu(h)
|
|
|
|
|
|
|
|
|
2020-09-08 14:17:27 +00:00
|
|
|
# Simple way to chain multiple conv->act->norms together in an intuitive way.
class MultiConvBlock(nn.Module):
    """Chain of ``depth`` bias-free ConvBnLelu layers with a learned output affine.

    Layer layout: filters_in -> filters_mid, then ``depth - 2`` filters_mid ->
    filters_mid layers, then a final filters_mid -> filters_out layer with no
    activation and no norm. The output is ``chain(x) * scale + bias`` where
    ``scale`` starts at ``scale_init`` and ``bias`` at 0.

    Args:
        filters_in, filters_mid, filters_out: channel counts as described above.
        kernel_size: forwarded to every ConvBnLelu.
        depth: total number of conv layers; must be >= 2.
        scale_init: initial value of the learned output scale.
        norm: whether the non-final layers use BatchNorm.
        weight_init_factor: forwarded to every ConvBnLelu.
    """

    def __init__(self, filters_in, filters_mid, filters_out, kernel_size, depth, scale_init=1, norm=False, weight_init_factor=1):
        assert depth >= 2
        super(MultiConvBlock, self).__init__()
        # Learned multiplier applied to the optional noise input in forward().
        self.noise_scale = nn.Parameter(torch.full((1,), fill_value=.01))
        self.bnconvs = nn.ModuleList([ConvBnLelu(filters_in, filters_mid, kernel_size, norm=norm, bias=False, weight_init_factor=weight_init_factor)] +
                                     [ConvBnLelu(filters_mid, filters_mid, kernel_size, norm=norm, bias=False, weight_init_factor=weight_init_factor) for _ in range(depth - 2)] +
                                     [ConvBnLelu(filters_mid, filters_out, kernel_size, activation=False, norm=False, bias=False, weight_init_factor=weight_init_factor)])
        self.scale = nn.Parameter(torch.full(
            (1,), fill_value=scale_init, dtype=torch.float))
        self.bias = nn.Parameter(torch.zeros(1))

    def forward(self, x, noise=None):
        if noise is not None:
            x = x + noise * self.noise_scale
        for m in self.bnconvs:
            # Call the module (not m.forward) so registered forward hooks run.
            x = m(x)
        return x * self.scale + self.bias
|
|
|
|
|
|
|
|
|
2020-07-10 21:53:41 +00:00
|
|
|
# Block that upsamples 2x and reduces incoming filters by 2x. It preserves structure by taking a passthrough feed
# along with the feature representation.
class ExpansionBlock(nn.Module):
    """Upsample features 2x, merge them with a passthrough tensor, and refine.

    ``filters_out`` defaults to ``filters_in // 2``. Every sub-layer is built with
    ``block`` (ConvGnSilu by default).
    """

    def __init__(self, filters_in, filters_out=None, block=ConvGnSilu):
        super().__init__()
        out_nf = filters_in // 2 if filters_out is None else filters_out
        self.decimate = block(
            filters_in, out_nf, kernel_size=1, bias=False, activation=False, norm=True)
        self.process_passthrough = block(
            out_nf, out_nf, kernel_size=3, bias=True, activation=False, norm=True)
        self.conjoin = block(out_nf * 2, out_nf,
                             kernel_size=3, bias=False, activation=True, norm=False)
        self.process = block(out_nf, out_nf,
                             kernel_size=3, bias=False, activation=True, norm=True)

    # input is the feature signal with shape (b, f, w, h)
    # passthrough is the structure signal with shape (b, f/2, w*2, h*2)
    # output is conjoined upsample with shape (b, f/2, w*2, h*2)
    def forward(self, input, passthrough):
        upsampled = self.decimate(F.interpolate(input, scale_factor=2, mode="nearest"))
        structure = self.process_passthrough(passthrough)
        merged = self.conjoin(torch.cat([upsampled, structure], dim=1))
        return self.process(merged)
|
|
|
|
|
|
|
|
|
2020-08-12 14:45:49 +00:00
|
|
|
# Block that upsamples 2x and reduces incoming filters by 2x. It preserves structure by taking a passthrough feed
# along with the feature representation.
# Differs from ExpansionBlock in that the conjoin step keeps 2 * filters_out channels and a final 3x3 `reduce`
# block brings them back down to filters_out.
class ExpansionBlock2(nn.Module):
    """Upsample features 2x, merge with a passthrough tensor in doubled channel space, then reduce.

    ``filters_out`` defaults to ``filters_in // 2``. Every sub-layer is built with
    ``block`` (ConvGnSilu by default).
    """

    def __init__(self, filters_in, filters_out=None, block=ConvGnSilu):
        super().__init__()
        out_nf = filters_in // 2 if filters_out is None else filters_out
        wide = out_nf * 2
        self.decimate = block(
            filters_in, out_nf, kernel_size=1, bias=False, activation=False, norm=True)
        self.process_passthrough = block(
            out_nf, out_nf, kernel_size=3, bias=True, activation=False, norm=True)
        self.conjoin = block(wide, wide,
                             kernel_size=3, bias=False, activation=True, norm=False)
        self.reduce = block(wide, out_nf,
                            kernel_size=3, bias=False, activation=True, norm=True)

    # input is the feature signal with shape (b, f, w, h)
    # passthrough is the structure signal with shape (b, f/2, w*2, h*2)
    # output is conjoined upsample with shape (b, f/2, w*2, h*2)
    def forward(self, input, passthrough):
        upsampled = self.decimate(F.interpolate(input, scale_factor=2, mode="nearest"))
        structure = self.process_passthrough(passthrough)
        merged = self.conjoin(torch.cat([upsampled, structure], dim=1))
        return self.reduce(merged)
|
|
|
|
|
|
|
|
|
2020-08-25 17:56:59 +00:00
|
|
|
# Similar to ExpansionBlock2 but does not upsample.
class ConjoinBlock(nn.Module):
    """Concatenate ``input`` with ``passthrough`` along channels, process, then project down.

    ``filters_out`` and ``filters_pt`` (passthrough channels) both default to
    ``filters_in``. A 3x3 ``block`` keeps the concatenated width, and a 1x1
    ``block`` without activation maps it to ``filters_out``.
    """

    def __init__(self, filters_in, filters_out=None, filters_pt=None, block=ConvGnSilu, norm=True):
        super().__init__()
        if filters_out is None:
            filters_out = filters_in
        if filters_pt is None:
            filters_pt = filters_in
        joined_nf = filters_in + filters_pt
        self.process = block(joined_nf, joined_nf,
                             kernel_size=3, bias=False, activation=True, norm=norm)
        self.decimate = block(joined_nf, filters_out,
                              kernel_size=1, bias=False, activation=False, norm=norm)

    def forward(self, input, passthrough):
        joined = torch.cat([input, passthrough], dim=1)
        return self.decimate(self.process(joined))
|
|
|
|
|
|
|
|
|
2020-09-08 21:14:23 +00:00
|
|
|
# Designed explicitly to join a mainline trunk with reference data. Implemented as a residual branch.
class ReferenceJoinBlock(nn.Module):
    """Residually add a MultiConvBlock branch computed from ``cat([x, ref])`` to ``x``.

    forward() returns a tuple ``(output, torch.std(branch))``. ``output`` is
    ``x + branch``, optionally passed through a final ``block`` conv when
    ``join`` is True. ``residual_weight_init_factor`` seeds both the branch's
    output scale and its conv weight multiplier.
    """

    def __init__(self, nf, residual_weight_init_factor=1, block=ConvGnLelu, final_norm=False, kernel_size=3, depth=3, join=True):
        super().__init__()
        self.branch = MultiConvBlock(nf * 2, nf + nf // 2, nf, kernel_size=kernel_size, depth=depth,
                                     scale_init=residual_weight_init_factor, norm=False,
                                     weight_init_factor=residual_weight_init_factor)
        self.join_conv = block(
            nf, nf, kernel_size=kernel_size, norm=final_norm, bias=False, activation=True) if join else None

    def forward(self, x, ref):
        branch = self.branch(torch.cat([x, ref], dim=1))
        residual = x + branch
        if self.join_conv is None:
            return residual, torch.std(branch)
        return self.join_conv(residual), torch.std(branch)
|
2020-09-08 21:14:23 +00:00
|
|
|
|
|
|
|
|
2020-08-03 16:25:37 +00:00
|
|
|
# Basic convolutional upsampling block that uses interpolate.
class UpconvBlock(nn.Module):
    """Nearest-neighbour 2x upsample followed by a 3x3 ``block`` conv.

    Args:
        filters_in: input channels.
        filters_out: output channels; defaults to ``filters_in`` when None.
        block: conv block class (ConvGnSilu by default).
        norm, activation, bias: forwarded to ``block``.
    """

    def __init__(self, filters_in, filters_out=None, block=ConvGnSilu, norm=True, activation=True, bias=False):
        super(UpconvBlock, self).__init__()
        if filters_out is None:
            # None is not a usable channel count for the conv blocks; keep the
            # channel count unchanged (same default as ConjoinBlock).
            filters_out = filters_in
        self.process = block(filters_in, filters_out, kernel_size=3,
                             bias=bias, activation=activation, norm=norm)

    def forward(self, x):
        x = F.interpolate(x, scale_factor=2, mode="nearest")
        return self.process(x)
|
2020-10-16 05:18:08 +00:00
|
|
|
|
|
|
|
|
|
|
|
# Scales an image up 2x and performs intermediary processing. Designed to be the final block in an SR network.
class FinalUpsampleBlock2x(nn.Module):
    """Final SR head: conv, upsample stage(s), then project to ``out_nc`` channels.

    With ``scale == 2`` a single UpconvBlock (nf -> nf // 2) is used. Any other
    ``scale`` inserts an extra nf -> nf UpconvBlock plus conv first, giving two
    upsample stages. No layer uses normalisation; the last conv has no bias.
    """

    def __init__(self, nf, block=ConvGnLelu, out_nc=3, scale=2):
        super().__init__()
        half_nf = nf // 2
        layers = [block(nf, nf, kernel_size=3, norm=False, activation=True, bias=True)]
        if scale != 2:
            layers += [
                UpconvBlock(nf, nf, block=block, norm=False, activation=True, bias=True),
                block(nf, nf, kernel_size=3, norm=False, activation=False, bias=True),
            ]
        layers += [
            UpconvBlock(nf, half_nf, block=block, norm=False, activation=True, bias=True),
            block(half_nf, half_nf, kernel_size=3, norm=False, activation=False, bias=True),
            block(half_nf, out_nc, kernel_size=3, norm=False, activation=False, bias=False),
        ]
        self.chain = nn.Sequential(*layers)

    def forward(self, x):
        return self.chain(x)
|
2022-03-16 18:04:00 +00:00
|
|
|
|
|
|
|
# Spatial gather: pulls one (row, col) location per batch element out of the input.
def gather_2d(input, index):
    """Select one spatial location per batch element from a 4D tensor.

    ``torch.gather`` indexes along a single dimension with an index tensor of the
    same rank as its input. This helper instead takes a (row, col) pair per batch
    element and pulls that location's channel vector out of ``input``.

    Args:
        input: tensor of shape (b, c, h, w). Non-contiguous tensors are accepted.
        index: int64 tensor of shape (b, 2); row ``i`` selects
            ``input[i, :, index[i, 0], index[i, 1]]``.

    Returns:
        Tensor of shape (b, c). When c == 1 the channel dimension is squeezed
        away and the result has shape (b,) (kept for backward compatibility).
    """
    b, c, h, w = input.shape
    # reshape rather than view: view raises on non-contiguous inputs (e.g. transposed).
    flat = input.reshape(b, c, h * w)
    # Row-major offset of each (row, col) pair inside the flattened h*w plane.
    linear = index[:, 0] * w + index[:, 1]
    # Same offset for every channel: (b,) -> (b, c, 1).
    linear = linear.unsqueeze(1).repeat((1, c)).unsqueeze(2)
    result = torch.gather(flat, dim=2, index=linear).squeeze(2)
    if c == 1:
        result = result.squeeze(1)
    return result
|