259 lines
9.5 KiB
Python
259 lines
9.5 KiB
Python
|
from collections import OrderedDict
|
||
|
import torch
|
||
|
import torch.nn as nn
|
||
|
|
||
|
####################
|
||
|
# Basic blocks
|
||
|
####################
|
||
|
|
||
|
|
||
|
def act(act_type, inplace=True, neg_slope=0.2, n_prelu=1):
    '''
    Build an activation layer selected by name (case-insensitive).

    act_type: one of 'relu', 'leakyrelu' or 'prelu'
    inplace: forwarded to ReLU / LeakyReLU (not used by PReLU)
    neg_slope: negative slope of LeakyReLU, and initial weight value of PReLU
    n_prelu: num_parameters of PReLU

    Raises NotImplementedError for any other name.
    '''
    kind = act_type.lower()
    if kind == 'relu':
        return nn.ReLU(inplace)
    if kind == 'leakyrelu':
        return nn.LeakyReLU(neg_slope, inplace)
    if kind == 'prelu':
        return nn.PReLU(num_parameters=n_prelu, init=neg_slope)
    # The message reports the lower-cased name.
    raise NotImplementedError('activation layer [{:s}] is not found'.format(kind))
|
||
|
|
||
|
def norm(norm_type, nc):
    '''
    Build a 2D normalization layer selected by name (case-insensitive).

    norm_type: 'batch' (BatchNorm2d, affine) or 'instance' (InstanceNorm2d, non-affine)
    nc: number of features passed to the layer

    Raises NotImplementedError for any other name.
    '''
    kind = norm_type.lower()
    if kind == 'batch':
        return nn.BatchNorm2d(nc, affine=True)
    if kind == 'instance':
        return nn.InstanceNorm2d(nc, affine=False)
    # The message reports the lower-cased name.
    raise NotImplementedError('normalization layer [{:s}] is not found'.format(kind))
|
||
|
|
||
|
def pad(pad_type, padding):
    '''
    Build an explicit padding layer selected by name (case-insensitive).

    pad_type: 'reflect' or 'replicate'
    padding: padding size handed to the layer; 0 means no layer is needed

    Returns None when padding == 0. Zero padding is not handled here
    (it is expected to be done by the conv layer itself), so any other
    pad_type with a non-zero padding raises NotImplementedError.
    '''
    kind = pad_type.lower()
    if padding == 0:
        return None
    if kind == 'reflect':
        return nn.ReflectionPad2d(padding)
    if kind == 'replicate':
        return nn.ReplicationPad2d(padding)
    raise NotImplementedError('padding layer [{:s}] is not implemented'.format(kind))
|
||
|
|
||
|
|
||
|
def get_valid_padding(kernel_size, dilation):
    '''
    Return the per-side padding for a (possibly dilated) kernel.

    The dilated kernel extent is computed first; the padding is half of
    that extent minus one, rounded down.
    '''
    dilated_extent = kernel_size + (kernel_size - 1) * (dilation - 1)
    return (dilated_extent - 1) // 2
|
||
|
|
||
|
|
||
|
class ConcatBlock(nn.Module):
    '''
    Concatenate the output of a submodule to its input along dim 1.
    '''

    def __init__(self, submodule):
        super(ConcatBlock, self).__init__()
        self.sub = submodule

    def forward(self, x):
        # Input first, submodule output second.
        return torch.cat((x, self.sub(x)), dim=1)

    def __repr__(self):
        # Prefix every line of the submodule's repr with '|'.
        inner = self.sub.__repr__().replace('\n', '\n|')
        return 'Identity .. \n|' + inner
|
||
|
|
||
|
|
||
|
class ShortcutBlock(nn.Module):
    '''
    Element-wise sum of a submodule's output and its input (skip connection).

    submodule: module applied to the input; its output is added to the input
    in forward, so the two are expected to have compatible shapes.
    '''

    def __init__(self, submodule):
        super(ShortcutBlock, self).__init__()
        self.sub = submodule

    def forward(self, x):
        # Apply the submodule and add its result to the input. Returning the
        # pair (x, self.sub) would hand back the module object itself and
        # never perform the sum.
        return x + self.sub(x)

    def __repr__(self):
        # Prefix every line of the submodule's repr with '|'.
        inner = self.sub.__repr__().replace('\n', '\n|')
        return 'Identity + \n|' + inner
|
||
|
|
||
|
|
||
|
def sequential(*args):
    '''
    Build a flattened nn.Sequential from the given modules.

    - A single argument is returned unchanged (no wrapper is created),
      except that an OrderedDict raises NotImplementedError.
    - nn.Sequential arguments are unwrapped into their direct children.
    - Arguments that are not nn.Module instances (e.g. None) are skipped.
    '''
    if len(args) == 1:
        only = args[0]
        if isinstance(only, OrderedDict):
            raise NotImplementedError('sequential does not support OrderedDict input.')
        return only  # No sequential is needed.

    flat = []
    for item in args:
        if isinstance(item, nn.Sequential):
            flat.extend(item.children())
        elif isinstance(item, nn.Module):
            flat.append(item)
    return nn.Sequential(*flat)
|
||
|
|
||
|
|
||
|
def conv_block(in_nc, out_nc, kernel_size, stride=1, dilation=1, groups=1, bias=True,
               pad_type='zero', norm_type=None, act_type='relu', mode='CNA'):
    '''
    Conv layer combined with optional padding, normalization and activation.

    mode:
        CNA / CNAC --> Pad -> Conv -> Norm -> Act
        NAC        --> Norm -> Act -> Pad -> Conv
                       (Identity Mappings in Deep Residual Networks, ECCV16)

    pad_type 'zero' pads inside the Conv2d itself; any other non-empty
    pad_type uses an explicit padding layer in front of an unpadded conv.
    The pieces are combined with sequential(), which drops the None entries.
    '''
    assert mode in ['CNA', 'NAC', 'CNAC'], 'Wrong conv mode [{:s}]'.format(mode)

    needed_pad = get_valid_padding(kernel_size, dilation)
    zero_padded = pad_type == 'zero'

    # Explicit padding layer only for a non-empty, non-'zero' pad type.
    pad_layer = None
    if pad_type and not zero_padded:
        pad_layer = pad(pad_type, needed_pad)

    conv_layer = nn.Conv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride,
                           padding=needed_pad if zero_padded else 0,
                           dilation=dilation, bias=bias, groups=groups)

    act_layer = None
    if act_type:
        act_layer = act(act_type)

    if 'CNA' in mode:
        # Normalization follows the conv, so it sees out_nc channels.
        norm_layer = None
        if norm_type:
            norm_layer = norm(norm_type, out_nc)
        return sequential(pad_layer, conv_layer, norm_layer, act_layer)

    if mode == 'NAC':
        if norm_type is None and act_type is not None:
            # Important!
            # input----ReLU(inplace)----Conv--+----output
            #        |________________________|
            # An in-place activation as the first layer would modify the
            # block input, which is also used by the skip path.
            act_layer = act(act_type, inplace=False)
        # Normalization precedes the conv, so it sees in_nc channels.
        norm_layer = None
        if norm_type:
            norm_layer = norm(norm_type, in_nc)
        return sequential(norm_layer, act_layer, pad_layer, conv_layer)
|
||
|
|
||
|
|
||
|
####################
|
||
|
# Useful blocks
|
||
|
####################
|
||
|
|
||
|
class ResNetBlock(nn.Module):
    '''
    ResNet Block, 3-3 style (two stacked conv blocks on the residual path),
    with the extra residual scaling used in EDSR
    (Enhanced Deep Residual Networks for Single Image Super-Resolution, CVPRW 17)

    forward computes x + res_scale * res(x).
    NOTE: no projection is applied to x on the skip path, so the residual
    branch output must be addable to the input.
    '''

    def __init__(self, in_nc, mid_nc, out_nc, kernel_size=3, stride=1, dilation=1, groups=1,
                 bias=True, pad_type='zero', norm_type=None, act_type='relu', mode='CNA',
                 res_scale=1):
        super(ResNetBlock, self).__init__()
        first = conv_block(in_nc, mid_nc, kernel_size, stride, dilation, groups, bias,
                           pad_type, norm_type, act_type, mode)

        # Settings for the second conv block depend on the mode.
        second_norm = norm_type
        second_act = act_type
        if mode == 'CNA':
            # No activation after the second conv.
            second_act = None
        if mode == 'CNAC':  # Residual path: |-CNAC-|
            # The second block is a bare conv.
            second_act = None
            second_norm = None
        second = conv_block(mid_nc, out_nc, kernel_size, stride, dilation, groups, bias,
                            pad_type, second_norm, second_act, mode)

        self.res = sequential(first, second)
        self.res_scale = res_scale

    def forward(self, x):
        scaled = self.res(x).mul(self.res_scale)
        return x + scaled
|
||
|
|
||
|
|
||
|
class ResidualDenseBlock_5C(nn.Module):
    '''
    Residual Dense Block
    style: 5 convs
    The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)

    Each of the first four convs produces gc channels and takes the input
    concatenated (along dim 1) with all previous conv outputs. The fifth
    conv maps everything back to nc channels; its output is scaled by 0.2
    and added to the block input.
    '''

    def __init__(self, nc, kernel_size=3, gc=32, stride=1, bias=True, pad_type='zero',
                 norm_type=None, act_type='leakyrelu', mode='CNA'):
        super(ResidualDenseBlock_5C, self).__init__()
        # gc: growth channel, i.e. intermediate channels
        shared = dict(bias=bias, pad_type=pad_type, norm_type=norm_type, mode=mode)
        self.conv1 = conv_block(nc, gc, kernel_size, stride, act_type=act_type, **shared)
        self.conv2 = conv_block(nc + gc, gc, kernel_size, stride, act_type=act_type, **shared)
        self.conv3 = conv_block(nc + 2 * gc, gc, kernel_size, stride, act_type=act_type, **shared)
        self.conv4 = conv_block(nc + 3 * gc, gc, kernel_size, stride, act_type=act_type, **shared)
        # In CNA mode the last conv has no activation.
        last_act = None if mode == 'CNA' else act_type
        # The last conv always uses a kernel size of 3, regardless of kernel_size.
        self.conv5 = conv_block(nc + 4 * gc, nc, 3, stride, act_type=last_act, **shared)

    def forward(self, x):
        # feats accumulates the input and every intermediate output, in order.
        feats = [x, self.conv1(x)]
        for conv in (self.conv2, self.conv3, self.conv4):
            feats.append(conv(torch.cat(feats, 1)))
        fused = self.conv5(torch.cat(feats, 1))
        return fused.mul(0.2) + x
|
||
|
|
||
|
|
||
|
class RRDB(nn.Module):
    '''
    Residual in Residual Dense Block
    (ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks)

    Three ResidualDenseBlock_5C modules applied in sequence; the result is
    scaled by 0.2 and added to the block input.
    '''

    def __init__(self, nc, kernel_size=3, gc=32, stride=1, bias=True, pad_type='zero',
                 norm_type=None, act_type='leakyrelu', mode='CNA'):
        super(RRDB, self).__init__()
        # All three dense blocks share the same configuration.
        rdb_args = (nc, kernel_size, gc, stride, bias, pad_type, norm_type, act_type, mode)
        self.RDB1 = ResidualDenseBlock_5C(*rdb_args)
        self.RDB2 = ResidualDenseBlock_5C(*rdb_args)
        self.RDB3 = ResidualDenseBlock_5C(*rdb_args)

    def forward(self, x):
        out = x
        for rdb in (self.RDB1, self.RDB2, self.RDB3):
            out = rdb(out)
        return out.mul(0.2) + x
|
||
|
|
||
|
|
||
|
####################
|
||
|
# Upsampler
|
||
|
####################
|
||
|
|
||
|
|
||
|
def pixelshuffle_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True,
                       pad_type='zero', norm_type=None, act_type='relu'):
    '''
    Pixel shuffle upsampling layer
    (Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel
    Convolutional Neural Network)

    A bare conv (no norm, no activation) expands the channels to
    out_nc * upscale_factor ** 2, nn.PixelShuffle rearranges them, and the
    optional norm and activation are applied afterwards on out_nc channels.
    '''
    expanded_nc = out_nc * (upscale_factor ** 2)
    conv = conv_block(in_nc, expanded_nc, kernel_size, stride, bias=bias,
                      pad_type=pad_type, norm_type=None, act_type=None)
    shuffle = nn.PixelShuffle(upscale_factor)

    norm_layer = None
    if norm_type:
        norm_layer = norm(norm_type, out_nc)
    act_layer = None
    if act_type:
        act_layer = act(act_type)
    return sequential(conv, shuffle, norm_layer, act_layer)
|
||
|
|
||
|
|
||
|
def upconv_blcok(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True,
                 pad_type='zero', norm_type=None, act_type='relu', mode='nearest'):
    '''
    Up conv: upsample first, then apply a conv block
    (described in https://distill.pub/2016/deconv-checkerboard/)

    mode: interpolation mode forwarded to nn.Upsample.
    NOTE: the function name is misspelled ("blcok"); it is kept as is
    because it is the name callers use.
    '''
    upsample = nn.Upsample(scale_factor=upscale_factor, mode=mode)
    conv = conv_block(in_nc, out_nc, kernel_size, stride, bias=bias,
                      pad_type=pad_type, norm_type=norm_type, act_type=act_type)
    return sequential(upsample, conv)
|