''' network architecture for EDVR '''
import functools

import torch
import torch.nn as nn
import torch.nn.functional as F

import models.archs.arch_util as arch_util

try:
    from models.archs.dcn.deform_conv import ModulatedDeformConvPack as DCN
except ImportError:
    raise ImportError('Failed to import DCNv2 module.')

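
# Note (added comment, not in the original file): throughout this file the DCN
# modules are built with extra_offset_mask=True and are therefore called with a
# two-element list, e.g. dcnpack([features_to_align, offset_features]); the second
# element is used to predict the offsets and modulation masks that are applied to
# the first.
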
class Predeblur_ResNet_Pyramid(nn.Module):
    def __init__(self, nf=128, HR_in=False):
        '''
        HR_in: True if the input is at the high (HR) spatial resolution
        '''

        super(Predeblur_ResNet_Pyramid, self).__init__()
        self.HR_in = True if HR_in else False
        if self.HR_in:
            self.conv_first_1 = nn.Conv2d(3, nf, 3, 1, 1, bias=True)
            self.conv_first_2 = nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
            self.conv_first_3 = nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
        else:
            self.conv_first = nn.Conv2d(3, nf, 3, 1, 1, bias=True)
        basic_block = functools.partial(arch_util.ResidualBlock_noBN, nf=nf)
        self.RB_L1_1 = basic_block()
        self.RB_L1_2 = basic_block()
        self.RB_L1_3 = basic_block()
        self.RB_L1_4 = basic_block()
        self.RB_L1_5 = basic_block()
        self.RB_L2_1 = basic_block()
        self.RB_L2_2 = basic_block()
        self.RB_L3_1 = basic_block()
        self.deblur_L2_conv = nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
        self.deblur_L3_conv = nn.Conv2d(nf, nf, 3, 2, 1, bias=True)

        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

    def forward(self, x):
        if self.HR_in:
            # HR inputs are reduced 4x by two stride-2 convs before deblurring
            L1_fea = self.lrelu(self.conv_first_1(x))
            L1_fea = self.lrelu(self.conv_first_2(L1_fea))
            L1_fea = self.lrelu(self.conv_first_3(L1_fea))
        else:
            L1_fea = self.lrelu(self.conv_first(x))
        # build a 3-level pyramid with stride-2 convs ...
        L2_fea = self.lrelu(self.deblur_L2_conv(L1_fea))
        L3_fea = self.lrelu(self.deblur_L3_conv(L2_fea))
        # ... then fuse it back coarse-to-fine with residual blocks and x2 upsampling
        L3_fea = F.interpolate(self.RB_L3_1(L3_fea), scale_factor=2, mode='bilinear',
                               align_corners=False)
        L2_fea = self.RB_L2_1(L2_fea) + L3_fea
        L2_fea = F.interpolate(self.RB_L2_2(L2_fea), scale_factor=2, mode='bilinear',
                               align_corners=False)
        L1_fea = self.RB_L1_2(self.RB_L1_1(L1_fea)) + L2_fea
        out = self.RB_L1_5(self.RB_L1_4(self.RB_L1_3(L1_fea)))
        return out
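
# Shape sketch for Predeblur_ResNet_Pyramid (illustrative addition, not part of the
# original file; the helper below is never called by the network itself):
#   HR_in=False: [B*N, 3, H, W]   -> [B*N, nf, H, W]  (spatial size preserved)
#   HR_in=True:  [B*N, 3, 4H, 4W] -> [B*N, nf, H, W]  (two stride-2 convs downscale by 4)
def _predeblur_shape_sketch():
    x = torch.randn(1, 3, 64, 64)
    out = Predeblur_ResNet_Pyramid(nf=128, HR_in=False)(x)
    assert out.shape == (1, 128, 64, 64)
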
class PCD_Align(nn.Module):
    ''' Alignment module using Pyramid, Cascading and Deformable convolution
    with 3 pyramid levels.
    '''

    def __init__(self, nf=64, groups=8):
        super(PCD_Align, self).__init__()
        # L3: level 3, 1/4 spatial size
        self.L3_offset_conv1 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)  # concat for diff
        self.L3_offset_conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.L3_dcnpack = DCN(nf, nf, 3, stride=1, padding=1, dilation=1, deformable_groups=groups,
                              extra_offset_mask=True)
        # L2: level 2, 1/2 spatial size
        self.L2_offset_conv1 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)  # concat for diff
        self.L2_offset_conv2 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)  # concat for offset
        self.L2_offset_conv3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.L2_dcnpack = DCN(nf, nf, 3, stride=1, padding=1, dilation=1, deformable_groups=groups,
                              extra_offset_mask=True)
        self.L2_fea_conv = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)  # concat for fea
        # L1: level 1, original spatial size
        self.L1_offset_conv1 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)  # concat for diff
        self.L1_offset_conv2 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)  # concat for offset
        self.L1_offset_conv3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.L1_dcnpack = DCN(nf, nf, 3, stride=1, padding=1, dilation=1, deformable_groups=groups,
                              extra_offset_mask=True)
        self.L1_fea_conv = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)  # concat for fea
        # Cascading DCN
        self.cas_offset_conv1 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)  # concat for diff
        self.cas_offset_conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)

        self.cas_dcnpack = DCN(nf, nf, 3, stride=1, padding=1, dilation=1, deformable_groups=groups,
                              extra_offset_mask=True)

        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

    def forward(self, nbr_fea_l, ref_fea_l):
        '''align the neighboring frames to the reference frame at the feature level
        nbr_fea_l, ref_fea_l: [L1, L2, L3], each with [B, C, H, W] features
        '''
        # L3
        L3_offset = torch.cat([nbr_fea_l[2], ref_fea_l[2]], dim=1)
        L3_offset = self.lrelu(self.L3_offset_conv1(L3_offset))
        L3_offset = self.lrelu(self.L3_offset_conv2(L3_offset))
        L3_fea = self.lrelu(self.L3_dcnpack([nbr_fea_l[2], L3_offset]))
        # L2
        L2_offset = torch.cat([nbr_fea_l[1], ref_fea_l[1]], dim=1)
        L2_offset = self.lrelu(self.L2_offset_conv1(L2_offset))
        # upsampled offsets are doubled to match the 2x larger spatial resolution
        L3_offset = F.interpolate(L3_offset, scale_factor=2, mode='bilinear', align_corners=False)
        L2_offset = self.lrelu(self.L2_offset_conv2(torch.cat([L2_offset, L3_offset * 2], dim=1)))
        L2_offset = self.lrelu(self.L2_offset_conv3(L2_offset))
        L2_fea = self.L2_dcnpack([nbr_fea_l[1], L2_offset])
        L3_fea = F.interpolate(L3_fea, scale_factor=2, mode='bilinear', align_corners=False)
        L2_fea = self.lrelu(self.L2_fea_conv(torch.cat([L2_fea, L3_fea], dim=1)))
        # L1
        L1_offset = torch.cat([nbr_fea_l[0], ref_fea_l[0]], dim=1)
        L1_offset = self.lrelu(self.L1_offset_conv1(L1_offset))
        L2_offset = F.interpolate(L2_offset, scale_factor=2, mode='bilinear', align_corners=False)
        L1_offset = self.lrelu(self.L1_offset_conv2(torch.cat([L1_offset, L2_offset * 2], dim=1)))
        L1_offset = self.lrelu(self.L1_offset_conv3(L1_offset))
        L1_fea = self.L1_dcnpack([nbr_fea_l[0], L1_offset])
        L2_fea = F.interpolate(L2_fea, scale_factor=2, mode='bilinear', align_corners=False)
        L1_fea = self.L1_fea_conv(torch.cat([L1_fea, L2_fea], dim=1))
        # Cascading: refine the aligned L1 features with one more DCN conditioned on the reference
        offset = torch.cat([L1_fea, ref_fea_l[0]], dim=1)
        offset = self.lrelu(self.cas_offset_conv1(offset))
        offset = self.lrelu(self.cas_offset_conv2(offset))
        L1_fea = self.lrelu(self.cas_dcnpack([L1_fea, offset]))

        return L1_fea
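
# Usage sketch for PCD_Align (illustrative addition, not part of the original file).
# The DCNv2 extension shipped with EDVR is a compiled CUDA op, so this assumes a GPU.
def _pcd_align_sketch():
    nf, B, H, W = 64, 1, 64, 64
    pcd = PCD_Align(nf=nf, groups=8).cuda()
    # 3-level feature pyramids at full, 1/2 and 1/4 resolution
    ref_fea_l = [torch.randn(B, nf, H, W).cuda(),
                 torch.randn(B, nf, H // 2, W // 2).cuda(),
                 torch.randn(B, nf, H // 4, W // 4).cuda()]
    nbr_fea_l = [fea.clone() for fea in ref_fea_l]
    aligned = pcd(nbr_fea_l, ref_fea_l)  # -> [B, nf, H, W], neighbor aligned to the reference
    return aligned
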
class TSA_Fusion(nn.Module):
    ''' Temporal and Spatial Attention (TSA) fusion module
    Temporal: correlation with the reference frame;
    Spatial: 3 pyramid levels.
    '''

    def __init__(self, nf=64, nframes=5, center=2):
        super(TSA_Fusion, self).__init__()
        self.center = center
        # temporal attention (before fusion conv)
        self.tAtt_1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.tAtt_2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)

        # fusion conv: using 1x1 to save parameters and computation
        self.fea_fusion = nn.Conv2d(nframes * nf, nf, 1, 1, bias=True)

        # spatial attention (after fusion conv)
        self.sAtt_1 = nn.Conv2d(nframes * nf, nf, 1, 1, bias=True)
        self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)
        self.avgpool = nn.AvgPool2d(3, stride=2, padding=1)
        self.sAtt_2 = nn.Conv2d(nf * 2, nf, 1, 1, bias=True)
        self.sAtt_3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.sAtt_4 = nn.Conv2d(nf, nf, 1, 1, bias=True)
        self.sAtt_5 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.sAtt_L1 = nn.Conv2d(nf, nf, 1, 1, bias=True)
        self.sAtt_L2 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)
        self.sAtt_L3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.sAtt_add_1 = nn.Conv2d(nf, nf, 1, 1, bias=True)
        self.sAtt_add_2 = nn.Conv2d(nf, nf, 1, 1, bias=True)

        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

    def forward(self, aligned_fea):
        B, N, C, H, W = aligned_fea.size()  # N video frames
        #### temporal attention
        emb_ref = self.tAtt_2(aligned_fea[:, self.center, :, :, :].clone())
        emb = self.tAtt_1(aligned_fea.view(-1, C, H, W)).view(B, N, -1, H, W)  # [B, N, C(nf), H, W]

        # per-pixel correlation of each frame embedding with the reference embedding
        cor_l = []
        for i in range(N):
            emb_nbr = emb[:, i, :, :, :]
            cor_tmp = torch.sum(emb_nbr * emb_ref, 1).unsqueeze(1)  # B, 1, H, W
            cor_l.append(cor_tmp)
        cor_prob = torch.sigmoid(torch.cat(cor_l, dim=1))  # B, N, H, W
        cor_prob = cor_prob.unsqueeze(2).repeat(1, 1, C, 1, 1).view(B, -1, H, W)
        aligned_fea = aligned_fea.view(B, -1, H, W) * cor_prob

        #### fusion
        fea = self.lrelu(self.fea_fusion(aligned_fea))

        #### spatial attention
        att = self.lrelu(self.sAtt_1(aligned_fea))
        att_max = self.maxpool(att)
        att_avg = self.avgpool(att)
        att = self.lrelu(self.sAtt_2(torch.cat([att_max, att_avg], dim=1)))
        # pyramid levels
        att_L = self.lrelu(self.sAtt_L1(att))
        att_max = self.maxpool(att_L)
        att_avg = self.avgpool(att_L)
        att_L = self.lrelu(self.sAtt_L2(torch.cat([att_max, att_avg], dim=1)))
        att_L = self.lrelu(self.sAtt_L3(att_L))
        att_L = F.interpolate(att_L, scale_factor=2, mode='bilinear', align_corners=False)

        att = self.lrelu(self.sAtt_3(att))
        att = att + att_L
        att = self.lrelu(self.sAtt_4(att))
        att = F.interpolate(att, scale_factor=2, mode='bilinear', align_corners=False)
        att = self.sAtt_5(att)
        att_add = self.sAtt_add_2(self.lrelu(self.sAtt_add_1(att)))
        att = torch.sigmoid(att)

        fea = fea * att * 2 + att_add
        return fea
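
# Usage sketch for TSA_Fusion (illustrative addition, not part of the original file).
# A stack of aligned per-frame features is fused into a single feature map:
def _tsa_fusion_sketch():
    B, N, nf, H, W = 1, 5, 64, 64, 64
    tsa = TSA_Fusion(nf=nf, nframes=N, center=N // 2)
    aligned_fea = torch.randn(B, N, nf, H, W)
    fused = tsa(aligned_fea)  # -> [B, nf, H, W]
    assert fused.shape == (B, nf, H, W)
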
class EDVR(nn.Module):
    ''' EDVR network: per-frame feature extraction, PCD alignment, TSA fusion and reconstruction '''

    def __init__(self, nf=64, nframes=5, groups=8, front_RBs=5, back_RBs=10, center=None,
                 predeblur=False, HR_in=False, w_TSA=True):
        super(EDVR, self).__init__()
        self.nf = nf
        self.center = nframes // 2 if center is None else center
        self.is_predeblur = True if predeblur else False
        self.HR_in = True if HR_in else False
        self.w_TSA = w_TSA
        ResidualBlock_noBN_f = functools.partial(arch_util.ResidualBlock_noBN, nf=nf)

        #### extract features (for each frame)
        if self.is_predeblur:
            self.pre_deblur = Predeblur_ResNet_Pyramid(nf=nf, HR_in=self.HR_in)
            self.conv_1x1 = nn.Conv2d(nf, nf, 1, 1, bias=True)
        else:
            if self.HR_in:
                self.conv_first_1 = nn.Conv2d(3, nf, 3, 1, 1, bias=True)
                self.conv_first_2 = nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
                self.conv_first_3 = nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
            else:
                self.conv_first = nn.Conv2d(3, nf, 3, 1, 1, bias=True)
        self.feature_extraction = arch_util.make_layer(ResidualBlock_noBN_f, front_RBs)
        self.fea_L2_conv1 = nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
        self.fea_L2_conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.fea_L3_conv1 = nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
        self.fea_L3_conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)

        self.pcd_align = PCD_Align(nf=nf, groups=groups)
        if self.w_TSA:
            self.tsa_fusion = TSA_Fusion(nf=nf, nframes=nframes, center=self.center)
        else:
            self.tsa_fusion = nn.Conv2d(nframes * nf, nf, 1, 1, bias=True)

        #### reconstruction
        self.recon_trunk = arch_util.make_layer(ResidualBlock_noBN_f, back_RBs)
        #### upsampling
        self.upconv1 = nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True)
        self.upconv2 = nn.Conv2d(nf, 64 * 4, 3, 1, 1, bias=True)
        self.pixel_shuffle = nn.PixelShuffle(2)
        self.HRconv = nn.Conv2d(64, 64, 3, 1, 1, bias=True)
        self.conv_last = nn.Conv2d(64, 3, 3, 1, 1, bias=True)

        #### activation function
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

    def forward(self, x):
        B, N, C, H, W = x.size()  # N video frames
        x_center = x[:, self.center, :, :, :].contiguous()

        #### extract LR features
        # L1
        if self.is_predeblur:
            L1_fea = self.pre_deblur(x.view(-1, C, H, W))
            L1_fea = self.conv_1x1(L1_fea)
            if self.HR_in:
                H, W = H // 4, W // 4
        else:
            if self.HR_in:
                L1_fea = self.lrelu(self.conv_first_1(x.view(-1, C, H, W)))
                L1_fea = self.lrelu(self.conv_first_2(L1_fea))
                L1_fea = self.lrelu(self.conv_first_3(L1_fea))
                H, W = H // 4, W // 4
            else:
                L1_fea = self.lrelu(self.conv_first(x.view(-1, C, H, W)))
        L1_fea = self.feature_extraction(L1_fea)
        # L2
        L2_fea = self.lrelu(self.fea_L2_conv1(L1_fea))
        L2_fea = self.lrelu(self.fea_L2_conv2(L2_fea))
        # L3
        L3_fea = self.lrelu(self.fea_L3_conv1(L2_fea))
        L3_fea = self.lrelu(self.fea_L3_conv2(L3_fea))

        L1_fea = L1_fea.view(B, N, -1, H, W)
        L2_fea = L2_fea.view(B, N, -1, H // 2, W // 2)
        L3_fea = L3_fea.view(B, N, -1, H // 4, W // 4)

        #### pcd align
        # ref feature list
        ref_fea_l = [
            L1_fea[:, self.center, :, :, :].clone(), L2_fea[:, self.center, :, :, :].clone(),
            L3_fea[:, self.center, :, :, :].clone()
        ]
        aligned_fea = []
        for i in range(N):
            nbr_fea_l = [
                L1_fea[:, i, :, :, :].clone(), L2_fea[:, i, :, :, :].clone(),
                L3_fea[:, i, :, :, :].clone()
            ]
            aligned_fea.append(self.pcd_align(nbr_fea_l, ref_fea_l))
        aligned_fea = torch.stack(aligned_fea, dim=1)  # [B, N, C, H, W]

        if not self.w_TSA:
            aligned_fea = aligned_fea.view(B, -1, H, W)
        fea = self.tsa_fusion(aligned_fea)

        out = self.recon_trunk(fea)
        out = self.lrelu(self.pixel_shuffle(self.upconv1(out)))
        out = self.lrelu(self.pixel_shuffle(self.upconv2(out)))
        out = self.lrelu(self.HRconv(out))
        out = self.conv_last(out)
        if self.HR_in:
            base = x_center
        else:
            base = F.interpolate(x_center, scale_factor=4, mode='bilinear', align_corners=False)
        out += base
        return out
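

# End-to-end smoke test (illustrative addition, not part of the original file).
# Assumes the DCNv2 extension is compiled and a CUDA device is available; with the
# default settings EDVR upscales 4x, so [B, N, 3, H, W] -> [B, 3, 4H, 4W].
if __name__ == '__main__':
    model = EDVR(nf=64, nframes=5, groups=8, front_RBs=5, back_RBs=10).cuda()
    lqs = torch.randn(1, 5, 3, 64, 64).cuda()  # 5 low-resolution input frames
    with torch.no_grad():
        sr = model(lqs)
    print(sr.shape)  # expected: torch.Size([1, 3, 256, 256])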