5f2c722a10
Big update to SRG2 architecture to pull in a lot of things that have been learned: - Use group norm instead of batch norm - Initialize the weights on the transformations low, as is done in RRDB, rather than using the scalar. Models live or die by their early stages, and this one's early stage is pretty weak - Transform multiplexer to use a U-Net-like architecture. - Just use one set of configuration variables instead of a list - flat networks performed fine in this regard.
366 lines
13 KiB
Python
366 lines
13 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.init as init
|
|
import torch.nn.functional as F
|
|
import torch.nn.utils.spectral_norm as SpectralNorm
|
|
from math import sqrt
|
|
|
|
def pixel_norm(x, epsilon=1e-8):
    """Rescale ``x`` by the reciprocal root-mean-square taken along dim 1.

    Args:
        x (Tensor): input tensor; the mean of squares is reduced over dim 1
            and kept as a singleton dimension so it broadcasts back over ``x``.
        epsilon (float): added to the mean square before the reciprocal
            square root, so an all-zero slice does not divide by zero.

    Returns:
        Tensor: ``x * rsqrt(mean(x ** 2, dim=1) + epsilon)``, same shape as ``x``.
    """
    mean_square = torch.pow(x, 2).mean(dim=1, keepdim=True)
    inv_rms = torch.rsqrt(mean_square + epsilon)
    return x * inv_rms
|
|
|
|
def initialize_weights(net_l, scale=1):
    """Kaiming-initialise the conv/linear layers found in one or more modules.

    Every ``nn.Conv2d``, ``nn.Conv3d`` and ``nn.Linear`` reachable through
    ``net.modules()`` gets ``kaiming_normal_`` weights (``a=0``,
    ``mode='fan_in'``) multiplied by ``scale``, and a zeroed bias when it has
    one. Every ``nn.BatchNorm2d`` is reset to weight 1 / bias 0. Other module
    types are left untouched.

    Args:
        net_l (nn.Module | list | tuple): a single module or a sequence of
            modules to initialise in place.
        scale (float): multiplier applied to the freshly drawn weights
            (values below 1 are used for residual blocks).
    """
    if not isinstance(net_l, (list, tuple)):
        net_l = [net_l]
    for net in net_l:
        for m in net.modules():
            if isinstance(m, (nn.Conv2d, nn.Conv3d, nn.Linear)):
                init.kaiming_normal_(m.weight, a=0, mode='fan_in')
                m.weight.data *= scale  # for residual block
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                # BatchNorm2d(affine=False) has weight and bias set to None;
                # there is nothing to reset in that case.
                if m.weight is not None:
                    init.constant_(m.weight, 1)
                if m.bias is not None:
                    init.constant_(m.bias.data, 0.0)
|
|
|
|
|
|
def make_layer(block, n_layers, return_layers=False):
    """Stack ``n_layers`` freshly built blocks into an ``nn.Sequential``.

    Args:
        block (callable): zero-argument factory; called once per layer.
        n_layers (int): number of blocks to create.
        return_layers (bool): when True, also return the Python list holding
            the created blocks.

    Returns:
        nn.Sequential, or the tuple ``(nn.Sequential, list)`` when
        ``return_layers`` is True.
    """
    built = [block() for _ in range(n_layers)]
    stacked = nn.Sequential(*built)
    if return_layers:
        return stacked, built
    return stacked
|
|
|
|
|
|
class ResidualBlock(nn.Module):
    '''Residual block with batch normalization.

    ---Conv-BN-LeakyReLU-Conv-BN-+-
     |___________________________|

    Both convolutions are 3x3, stride 1, padding 1 and keep ``nf`` channels,
    so the branch output can be added straight onto the input.
    '''

    def __init__(self, nf=64):
        super().__init__()
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
        self.conv1 = nn.Conv2d(nf, nf, kernel_size=3, stride=1, padding=1, bias=True)
        self.BN1 = nn.BatchNorm2d(nf)
        self.conv2 = nn.Conv2d(nf, nf, kernel_size=3, stride=1, padding=1, bias=True)
        self.BN2 = nn.BatchNorm2d(nf)

        # Draw the conv weights with initialize_weights at scale 0.1 so the
        # residual branch starts out small.
        initialize_weights([self.conv1, self.conv2], 0.1)

    def forward(self, x):
        branch = self.BN1(self.conv1(x))
        branch = self.lrelu(branch)
        branch = self.BN2(self.conv2(branch))
        return x + branch
|
|
|
|
class ResidualBlockSpectralNorm(nn.Module):
    '''Residual block whose convolutions are wrapped in spectral normalization.

    ---SpecConv-LeakyReLU-SpecConv-+-
     |_____________________________|

    ``total_residual_blocks`` is accepted by the constructor but is not used
    anywhere in this class.
    '''

    def __init__(self, nf, total_residual_blocks):
        super().__init__()
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
        self.conv1 = SpectralNorm(nn.Conv2d(nf, nf, kernel_size=3, stride=1, padding=1, bias=True))
        self.conv2 = SpectralNorm(nn.Conv2d(nf, nf, kernel_size=3, stride=1, padding=1, bias=True))

        # Unlike the other residual blocks, the weights are not scaled down
        # (scale 1).
        initialize_weights([self.conv1, self.conv2], 1)

    def forward(self, x):
        branch = self.lrelu(self.conv1(x))
        branch = self.conv2(branch)
        return x + branch
|
|
|
|
class ResidualBlock_noBN(nn.Module):
    '''Residual block without any normalization layer.

    ---Conv-LeakyReLU-Conv-+-
     |_____________________|

    Both convolutions are 3x3, stride 1, padding 1 and keep ``nf`` channels.
    '''

    def __init__(self, nf=64):
        super().__init__()
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
        self.conv1 = nn.Conv2d(nf, nf, kernel_size=3, stride=1, padding=1, bias=True)
        self.conv2 = nn.Conv2d(nf, nf, kernel_size=3, stride=1, padding=1, bias=True)

        # Draw the conv weights with initialize_weights at scale 0.1 so the
        # residual branch starts out small.
        initialize_weights([self.conv1, self.conv2], 0.1)

    def forward(self, x):
        branch = self.lrelu(self.conv1(x))
        branch = self.conv2(branch)
        return x + branch
|
|
|
|
|
|
def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros', align_corners=True):
    """Warp an image or feature map with optical flow.

    A base grid of pixel coordinates is offset by ``flow``, rescaled to the
    [-1, 1] range expected by ``F.grid_sample`` and used to resample ``x``.

    Args:
        x (Tensor): size (N, C, H, W)
        flow (Tensor): size (N, H, W, 2), displacement in pixels (not
            normalized). Index 0 of the last dim is added to the x (width)
            coordinate, index 1 to the y (height) coordinate.
        interp_mode (str): 'nearest' or 'bilinear'
        padding_mode (str): 'zeros' or 'border' or 'reflection'
        align_corners (bool): forwarded to ``F.grid_sample``. The grid is
            normalized by ``W - 1`` / ``H - 1``, i.e. -1 and 1 address the
            centres of the corner pixels, which is the ``align_corners=True``
            convention; with it a zero flow reproduces ``x``. Pass ``False``
            to get the sampling this function produced when it relied on the
            ``grid_sample`` default of PyTorch >= 1.3.

    Returns:
        Tensor: warped image or feature map
    """
    assert x.size()[-2:] == flow.size()[1:3], 'x and flow must have the same spatial size'
    _, _, H, W = x.size()

    # Base grid of pixel coordinates, built by broadcasting so it does not
    # depend on torch.meshgrid's indexing default.
    cols = torch.arange(0, W, device=x.device).view(1, W).expand(H, W)
    rows = torch.arange(0, H, device=x.device).view(H, 1).expand(H, W)
    grid = torch.stack((cols, rows), 2).type_as(x)  # (H, W, 2), last dim is (x, y)

    # Broadcasts over the batch dimension of flow.
    vgrid = grid + flow

    # scale grid to [-1,1]; max(..., 1) guards a 1-pixel-wide/high input
    vgrid_x = 2.0 * vgrid[..., 0] / max(W - 1, 1) - 1.0
    vgrid_y = 2.0 * vgrid[..., 1] / max(H - 1, 1) - 1.0
    vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3)

    return F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode,
                         align_corners=align_corners)
|
|
|
|
|
|
class PixelUnshuffle(nn.Module):
    """Fold ``r x r`` spatial blocks into channels.

    A ``(b, f, d2, d3)`` input becomes ``(b, f * r * r, d2 // r, d3 // r)``.
    Output channel ``c * r * r + i * r + j`` holds the elements of input
    channel ``c`` found at offset ``(i, j)`` inside each block.
    """

    def __init__(self, reduction_factor):
        super().__init__()
        self.r = reduction_factor

    def forward(self, x):
        r = self.r
        batch, channels, d2, d3 = x.shape
        out_d2, out_d3 = d2 // r, d3 // r
        # Split each of the two trailing dims into (blocks, offset-in-block).
        blocks = x.contiguous().view(batch, channels, out_d2, r, out_d3, r)
        # Move both offset axes next to the channel axis, then merge them in.
        blocks = blocks.permute(0, 1, 3, 5, 2, 4).contiguous()
        return blocks.view(batch, channels * (r ** 2), out_d2, out_d3)
|
|
|
|
|
|
# Functional form of the SiLU activation.
def silu(input):
    '''
    Sigmoid Linear Unit, applied element-wise:

        SiLU(x) = x * sigmoid(x)

    Returns a tensor with the same shape as ``input``.
    '''
    gate = torch.sigmoid(input)
    return input * gate
|
|
|
|
# nn.Module wrapper around silu() so the activation can be used as a layer
# inside models.
class SiLU(nn.Module):
    '''
    Applies the Sigmoid Linear Unit (SiLU) function element-wise:

        SiLU(x) = x * sigmoid(x)

    The module holds no parameters or state; it only forwards to ``silu``.

    Shape:
        - Input: (N, *) where * means, any number of additional
          dimensions
        - Output: (N, *), same shape as the input

    References:
        - Related paper:
          https://arxiv.org/pdf/1606.08415.pdf

    Examples:
        >>> m = SiLU()
        >>> input = torch.randn(2)
        >>> output = m(input)
    '''

    def forward(self, input):
        '''
        Forward pass: return ``silu(input)``.
        '''
        return silu(input)
|
|
|
|
|
|
class ConvBnRelu(nn.Module):
    ''' Convenience class with Conv->BN->ReLU. Includes weight initialization and auto-padding for standard
        kernel sizes.

        Args:
            filters_in (int): input channels of the convolution.
            filters_out (int): output channels of the convolution (and of the batch norm).
            kernel_size (int): one of 1, 3, 5, 7; the padding is chosen so that a stride-1 convolution keeps
                the spatial size.
            stride (int): convolution stride.
            relu (bool): append an nn.ReLU when True.
            bn (bool): append an nn.BatchNorm2d when True.
            bias (bool): whether the convolution has a bias term. The bias keeps the default nn.Conv2d
                initialization.
            weight_init_factor (float): multiplier applied to the convolution weights after Kaiming
                initialization; the default of 1 leaves them unscaled.
    '''
    def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, relu=True, bn=True, bias=True,
                 weight_init_factor=1):
        super(ConvBnRelu, self).__init__()
        padding_map = {1: 0, 3: 1, 5: 2, 7: 3}
        assert kernel_size in padding_map, 'kernel_size must be one of 1, 3, 5, 7'
        self.conv = nn.Conv2d(filters_in, filters_out, kernel_size, stride, padding_map[kernel_size], bias=bias)
        self.bn = nn.BatchNorm2d(filters_out) if bn else None
        self.relu = nn.ReLU() if relu else None

        # Init params. The gain follows the activation that actually comes after the convolution.
        nn.init.kaiming_normal_(self.conv.weight, mode='fan_out',
                                nonlinearity='relu' if self.relu is not None else 'linear')
        self.conv.weight.data *= weight_init_factor
        if self.bn is not None:
            nn.init.constant_(self.bn.weight, 1)
            nn.init.constant_(self.bn.bias, 0)

    def forward(self, x):
        x = self.conv(x)
        # Optional stages are compared against None explicitly instead of relying on module truthiness.
        if self.bn is not None:
            x = self.bn(x)
        if self.relu is not None:
            x = self.relu(x)
        return x
|
|
|
|
|
|
class ConvBnSilu(nn.Module):
    ''' Convenience class with Conv->BN->SiLU. Includes weight initialization and auto-padding for standard
        kernel sizes.

        kernel_size must be one of 1, 3, 5, 7; the matching padding keeps the spatial size for stride 1.
        The batch norm and the SiLU are each optional (bn / silu flags). Convolution weights are Kaiming
        initialized and multiplied by weight_init_factor; the convolution bias, when present, starts at zero.
    '''
    def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, silu=True, bn=True, bias=True, weight_init_factor=1):
        super(ConvBnSilu, self).__init__()
        padding_map = {1: 0, 3: 1, 5: 2, 7: 3}
        assert kernel_size in padding_map
        self.conv = nn.Conv2d(filters_in, filters_out, kernel_size, stride, padding_map[kernel_size], bias=bias)
        self.bn = nn.BatchNorm2d(filters_out) if bn else None
        self.silu = SiLU() if silu else None

        # Init params. The 'relu' gain is used whenever the SiLU stage is present.
        conv = self.conv
        gain_for = 'linear' if self.silu is None else 'relu'
        nn.init.kaiming_normal_(conv.weight, mode='fan_out', nonlinearity=gain_for)
        conv.weight.data.mul_(weight_init_factor)
        if conv.bias is not None:
            conv.bias.data.zero_()
        if self.bn is not None:
            nn.init.constant_(self.bn.weight, 1)
            nn.init.constant_(self.bn.bias, 0)

    def forward(self, x):
        out = self.conv(x)
        if self.bn is not None:
            out = self.bn(out)
        if self.silu is None:
            return out
        return self.silu(out)
|
|
|
|
|
|
class ConvBnLelu(nn.Module):
    ''' Convenience class with Conv->BN->LeakyReLU. Includes weight initialization and auto-padding for standard
        kernel sizes.

        kernel_size must be one of 1, 3, 5, 7; the matching padding keeps the spatial size for stride 1.
        The batch norm and the LeakyReLU (negative slope 0.1) are each optional (bn / lelu flags).
        Convolution weights are Kaiming initialized and multiplied by weight_init_factor; the convolution
        bias, when present, starts at zero.
    '''
    def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, lelu=True, bn=True, bias=True, weight_init_factor=1):
        super(ConvBnLelu, self).__init__()
        padding_map = {1: 0, 3: 1, 5: 2, 7: 3}
        assert kernel_size in padding_map
        self.conv = nn.Conv2d(filters_in, filters_out, kernel_size, stride, padding_map[kernel_size], bias=bias)
        self.bn = nn.BatchNorm2d(filters_out) if bn else None
        self.lelu = nn.LeakyReLU(negative_slope=.1) if lelu else None

        # Init params. The gain uses the same 0.1 slope as the activation when it is present.
        conv = self.conv
        gain_for = 'linear' if self.lelu is None else 'leaky_relu'
        nn.init.kaiming_normal_(conv.weight, a=.1, mode='fan_out', nonlinearity=gain_for)
        conv.weight.data.mul_(weight_init_factor)
        if conv.bias is not None:
            conv.bias.data.zero_()
        if self.bn is not None:
            nn.init.constant_(self.bn.weight, 1)
            nn.init.constant_(self.bn.bias, 0)

    def forward(self, x):
        out = self.conv(x)
        if self.bn is not None:
            out = self.bn(out)
        if self.lelu is None:
            return out
        return self.lelu(out)
|
|
|
|
|
|
class ConvGnLelu(nn.Module):
    ''' Convenience class with Conv->GroupNorm->LeakyReLU. Includes weight initialization and auto-padding for standard
        kernel sizes.

        Args:
            filters_in (int): input channels of the convolution.
            filters_out (int): output channels of the convolution (and of the group norm).
            kernel_size (int): one of 1, 3, 5, 7; the padding is chosen so that a stride-1 convolution keeps
                the spatial size.
            stride (int): convolution stride.
            lelu (bool): append an nn.LeakyReLU (negative slope 0.1) when True.
            gn (bool): append an nn.GroupNorm(num_groups, filters_out) when True.
            bias (bool): whether the convolution has a bias term. The bias keeps the default nn.Conv2d
                initialization.
            num_groups (int): number of groups handed to nn.GroupNorm.
            weight_init_factor (float): multiplier applied to the convolution weights after Kaiming
                initialization; the default of 1 leaves them unscaled.
    '''
    def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, lelu=True, gn=True, bias=True, num_groups=8,
                 weight_init_factor=1):
        super(ConvGnLelu, self).__init__()
        padding_map = {1: 0, 3: 1, 5: 2, 7: 3}
        assert kernel_size in padding_map, 'kernel_size must be one of 1, 3, 5, 7'
        self.conv = nn.Conv2d(filters_in, filters_out, kernel_size, stride, padding_map[kernel_size], bias=bias)
        self.gn = nn.GroupNorm(num_groups, filters_out) if gn else None
        self.lelu = nn.LeakyReLU(negative_slope=.1) if lelu else None

        # Init params. The gain uses the same 0.1 slope as the activation when it is present.
        nn.init.kaiming_normal_(self.conv.weight, a=.1, mode='fan_out',
                                nonlinearity='leaky_relu' if self.lelu is not None else 'linear')
        self.conv.weight.data *= weight_init_factor
        if self.gn is not None:
            nn.init.constant_(self.gn.weight, 1)
            nn.init.constant_(self.gn.bias, 0)

    def forward(self, x):
        x = self.conv(x)
        # Optional stages are compared against None explicitly instead of relying on module truthiness.
        if self.gn is not None:
            x = self.gn(x)
        if self.lelu is not None:
            x = self.lelu(x)
        return x
|
|
|
|
class ConvGnSilu(nn.Module):
    ''' Convenience class with Conv->GroupNorm->SiLU. Includes weight initialization and auto-padding for standard
        kernel sizes.

        kernel_size must be one of 1, 3, 5, 7; the matching padding keeps the spatial size for stride 1.
        The group norm (num_groups groups over filters_out channels) and the SiLU are each optional
        (gn / silu flags). Convolution weights are Kaiming initialized and multiplied by weight_init_factor;
        the convolution bias, when present, starts at zero.
    '''
    def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, silu=True, gn=True, bias=True, num_groups=8, weight_init_factor=1):
        super(ConvGnSilu, self).__init__()
        padding_map = {1: 0, 3: 1, 5: 2, 7: 3}
        assert kernel_size in padding_map
        self.conv = nn.Conv2d(filters_in, filters_out, kernel_size, stride, padding_map[kernel_size], bias=bias)
        self.gn = nn.GroupNorm(num_groups, filters_out) if gn else None
        self.silu = SiLU() if silu else None

        # Init params. The 'relu' gain is used whenever the SiLU stage is present.
        conv = self.conv
        gain_for = 'linear' if self.silu is None else 'relu'
        nn.init.kaiming_normal_(conv.weight, mode='fan_out', nonlinearity=gain_for)
        conv.weight.data.mul_(weight_init_factor)
        if conv.bias is not None:
            conv.bias.data.zero_()
        if self.gn is not None:
            nn.init.constant_(self.gn.weight, 1)
            nn.init.constant_(self.gn.bias, 0)

    def forward(self, x):
        out = self.conv(x)
        if self.gn is not None:
            out = self.gn(out)
        if self.silu is None:
            return out
        return self.silu(out)