''' network architecture for EDVR

Modules defined here:
    Predeblur_ResNet_Pyramid: optional 3-level residual pyramid used as the
        per-frame feature extractor when EDVR(predeblur=True).
    PCD_Align: aligns one neighbouring frame's feature pyramid to the
        reference frame's pyramid with deformable convolutions.
    TSA_Fusion: fuses the aligned frames with temporal and spatial attention.
    EDVR: the full network.
'''
import functools
import torch
import torch.nn as nn
import torch.nn.functional as F
import models.archs.arch_util as arch_util
try:
    from models.archs.dcn.deform_conv import ModulatedDeformConvPack as DCN
except ImportError as err:
    # chain the original error so the real cause (e.g. an unbuilt extension) stays visible
    raise ImportError('Failed to import DCNv2 module.') from err


def _upsample_like(x, ref):
    '''Bilinearly resize ``x`` to the spatial size (last two dims) of ``ref``.

    Every pyramid step in this file halves the resolution with a stride-2
    convolution or pooling layer (kernel 3, padding 1), which produces
    ceil(size / 2). When ``ref`` is exactly twice the size of ``x`` this
    samples the same grid as ``scale_factor=2``; for odd reference sizes it
    still lands on the reference size instead of overshooting by one pixel.
    '''
    return F.interpolate(x, size=ref.shape[-2:], mode='bilinear', align_corners=False)


class Predeblur_ResNet_Pyramid(nn.Module):
    '''Feature extractor built as a 3-level residual pyramid.

    Features are computed at three scales (the first feature map, then two
    stride-2 reductions), refined with residual blocks, and merged
    coarse-to-fine by bilinear upsampling and addition.
    '''

    def __init__(self, nf=128, HR_in=False):
        '''
        Args:
            nf (int): number of feature channels.
            HR_in (bool): True if the inputs are high spatial size; the input
                is then reduced by two extra stride-2 convolutions before the
                pyramid.
        '''
        super(Predeblur_ResNet_Pyramid, self).__init__()
        self.HR_in = bool(HR_in)
        if self.HR_in:
            self.conv_first_1 = nn.Conv2d(3, nf, 3, 1, 1, bias=True)
            self.conv_first_2 = nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
            self.conv_first_3 = nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
        else:
            self.conv_first = nn.Conv2d(3, nf, 3, 1, 1, bias=True)
        basic_block = functools.partial(arch_util.ResidualBlock_noBN, nf=nf)
        self.RB_L1_1 = basic_block()
        self.RB_L1_2 = basic_block()
        self.RB_L1_3 = basic_block()
        self.RB_L1_4 = basic_block()
        self.RB_L1_5 = basic_block()
        self.RB_L2_1 = basic_block()
        self.RB_L2_2 = basic_block()
        self.RB_L3_1 = basic_block()
        self.deblur_L2_conv = nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
        self.deblur_L3_conv = nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

    def forward(self, x):
        '''
        Args:
            x (Tensor): [B, 3, H, W] input images.
        Returns:
            Tensor: [B, nf, H', W'] features; H' x W' equals H x W, or the
                size after two stride-2 convolutions when HR_in is True.
        '''
        if self.HR_in:
            L1_fea = self.lrelu(self.conv_first_1(x))
            L1_fea = self.lrelu(self.conv_first_2(L1_fea))
            L1_fea = self.lrelu(self.conv_first_3(L1_fea))
        else:
            L1_fea = self.lrelu(self.conv_first(x))
        L2_fea = self.lrelu(self.deblur_L2_conv(L1_fea))
        L3_fea = self.lrelu(self.deblur_L3_conv(L2_fea))
        # coarse-to-fine: each coarser level is upsampled to the next finer level's size
        L3_fea = _upsample_like(self.RB_L3_1(L3_fea), L2_fea)
        L2_fea = self.RB_L2_1(L2_fea) + L3_fea
        L2_fea = _upsample_like(self.RB_L2_2(L2_fea), L1_fea)
        L1_fea = self.RB_L1_2(self.RB_L1_1(L1_fea)) + L2_fea
        out = self.RB_L1_5(self.RB_L1_4(self.RB_L1_3(L1_fea)))
        return out


class PCD_Align(nn.Module):
    ''' Alignment module using Pyramid, Cascading and Deformable convolution
    with 3 pyramid levels.
    '''

    def __init__(self, nf=64, groups=8):
        '''
        Args:
            nf (int): number of feature channels at every pyramid level.
            groups (int): deformable groups passed to each DCN layer.
        '''
        super(PCD_Align, self).__init__()
        # L3: level 3, 1/4 spatial size
        self.L3_offset_conv1 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)  # concat for diff
        self.L3_offset_conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.L3_dcnpack = DCN(nf, nf, 3, stride=1, padding=1, dilation=1, deformable_groups=groups,
                              extra_offset_mask=True)
        # L2: level 2, 1/2 spatial size
        self.L2_offset_conv1 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)  # concat for diff
        self.L2_offset_conv2 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)  # concat for offset
        self.L2_offset_conv3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.L2_dcnpack = DCN(nf, nf, 3, stride=1, padding=1, dilation=1, deformable_groups=groups,
                              extra_offset_mask=True)
        self.L2_fea_conv = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)  # concat for fea
        # L1: level 1, original spatial size
        self.L1_offset_conv1 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)  # concat for diff
        self.L1_offset_conv2 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)  # concat for offset
        self.L1_offset_conv3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.L1_dcnpack = DCN(nf, nf, 3, stride=1, padding=1, dilation=1, deformable_groups=groups,
                              extra_offset_mask=True)
        self.L1_fea_conv = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)  # concat for fea
        # Cascading DCN
        self.cas_offset_conv1 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)  # concat for diff
        self.cas_offset_conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.cas_dcnpack = DCN(nf, nf, 3, stride=1, padding=1, dilation=1, deformable_groups=groups,
                               extra_offset_mask=True)
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

    def forward(self, nbr_fea_l, ref_fea_l):
        '''Align a neighboring frame to the reference frame in the feature level.

        Args:
            nbr_fea_l, ref_fea_l (list[Tensor]): [L1, L2, L3] feature pyramids,
                each level [B, nf, H_l, W_l]; L2 and L3 are successively
                coarser than L1.
        Returns:
            Tensor: [B, nf, H_1, W_1] aligned neighbor features at L1 size.
        '''
        # L3
        L3_offset = torch.cat([nbr_fea_l[2], ref_fea_l[2]], dim=1)
        L3_offset = self.lrelu(self.L3_offset_conv1(L3_offset))
        L3_offset = self.lrelu(self.L3_offset_conv2(L3_offset))
        L3_fea = self.lrelu(self.L3_dcnpack([nbr_fea_l[2], L3_offset]))
        # L2
        L2_offset = torch.cat([nbr_fea_l[1], ref_fea_l[1]], dim=1)
        L2_offset = self.lrelu(self.L2_offset_conv1(L2_offset))
        L3_offset = _upsample_like(L3_offset, L2_offset)
        # NOTE(review): the x2 presumably compensates for the coarser level's
        # doubled pixel spacing; kept as-is for weight compatibility.
        L2_offset = self.lrelu(self.L2_offset_conv2(torch.cat([L2_offset, L3_offset * 2], dim=1)))
        L2_offset = self.lrelu(self.L2_offset_conv3(L2_offset))
        L2_fea = self.L2_dcnpack([nbr_fea_l[1], L2_offset])
        L3_fea = _upsample_like(L3_fea, L2_fea)
        L2_fea = self.lrelu(self.L2_fea_conv(torch.cat([L2_fea, L3_fea], dim=1)))
        # L1
        L1_offset = torch.cat([nbr_fea_l[0], ref_fea_l[0]], dim=1)
        L1_offset = self.lrelu(self.L1_offset_conv1(L1_offset))
        L2_offset = _upsample_like(L2_offset, L1_offset)
        L1_offset = self.lrelu(self.L1_offset_conv2(torch.cat([L1_offset, L2_offset * 2], dim=1)))
        L1_offset = self.lrelu(self.L1_offset_conv3(L1_offset))
        L1_fea = self.L1_dcnpack([nbr_fea_l[0], L1_offset])
        L2_fea = _upsample_like(L2_fea, L1_fea)
        L1_fea = self.L1_fea_conv(torch.cat([L1_fea, L2_fea], dim=1))
        # Cascading: one more DCN refinement against the reference L1 features
        offset = torch.cat([L1_fea, ref_fea_l[0]], dim=1)
        offset = self.lrelu(self.cas_offset_conv1(offset))
        offset = self.lrelu(self.cas_offset_conv2(offset))
        L1_fea = self.lrelu(self.cas_dcnpack([L1_fea, offset]))
        return L1_fea


class TSA_Fusion(nn.Module):
    ''' Temporal Spatial Attention fusion module
    Temporal: correlation; Spatial: 3 pyramid levels.
    '''

    def __init__(self, nf=64, nframes=5, center=2):
        '''
        Args:
            nf (int): feature channels per frame (input C must equal nf).
            nframes (int): number of frames N (fea_fusion expects N * nf inputs).
            center (int): index of the reference frame along the frame axis.
        '''
        super(TSA_Fusion, self).__init__()
        self.center = center
        # temporal attention (before fusion conv)
        self.tAtt_1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.tAtt_2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)

        # fusion conv: using 1x1 to save parameters and computation
        self.fea_fusion = nn.Conv2d(nframes * nf, nf, 1, 1, bias=True)

        # spatial attention (after fusion conv)
        self.sAtt_1 = nn.Conv2d(nframes * nf, nf, 1, 1, bias=True)
        self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)
        self.avgpool = nn.AvgPool2d(3, stride=2, padding=1)
        self.sAtt_2 = nn.Conv2d(nf * 2, nf, 1, 1, bias=True)
        self.sAtt_3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.sAtt_4 = nn.Conv2d(nf, nf, 1, 1, bias=True)
        self.sAtt_5 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.sAtt_L1 = nn.Conv2d(nf, nf, 1, 1, bias=True)
        self.sAtt_L2 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)
        self.sAtt_L3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.sAtt_add_1 = nn.Conv2d(nf, nf, 1, 1, bias=True)
        self.sAtt_add_2 = nn.Conv2d(nf, nf, 1, 1, bias=True)

        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

    def forward(self, aligned_fea):
        '''
        Args:
            aligned_fea (Tensor): [B, N, C, H, W] aligned features of N frames.
        Returns:
            Tensor: [B, nf, H, W] fused features.
        '''
        B, N, C, H, W = aligned_fea.size()  # N video frames
        #### temporal attention
        emb_ref = self.tAtt_2(aligned_fea[:, self.center, :, :, :].clone())
        emb = self.tAtt_1(aligned_fea.view(-1, C, H, W)).view(B, N, -1, H, W)  # [B, N, C(nf), H, W]

        # per-pixel dot product of each frame's embedding with the reference
        # embedding, one frame at a time so each temporary stays [B, C, H, W]
        cor_l = [torch.sum(emb_nbr * emb_ref, 1, keepdim=True) for emb_nbr in emb.unbind(1)]  # each B, 1, H, W
        cor_prob = torch.sigmoid(torch.cat(cor_l, dim=1))  # B, N, H, W
        # broadcast each frame's weight map over its C channels
        cor_prob = cor_prob.unsqueeze(2).repeat(1, 1, C, 1, 1).view(B, -1, H, W)
        aligned_fea = aligned_fea.view(B, -1, H, W) * cor_prob

        #### fusion
        fea = self.lrelu(self.fea_fusion(aligned_fea))

        #### spatial attention
        att = self.lrelu(self.sAtt_1(aligned_fea))
        att_max = self.maxpool(att)
        att_avg = self.avgpool(att)
        att = self.lrelu(self.sAtt_2(torch.cat([att_max, att_avg], dim=1)))
        # pyramid levels
        att_L = self.lrelu(self.sAtt_L1(att))
        att_max = self.maxpool(att_L)
        att_avg = self.avgpool(att_L)
        att_L = self.lrelu(self.sAtt_L2(torch.cat([att_max, att_avg], dim=1)))
        att_L = self.lrelu(self.sAtt_L3(att_L))
        att_L = _upsample_like(att_L, att)

        att = self.lrelu(self.sAtt_3(att))
        att = att + att_L
        att = self.lrelu(self.sAtt_4(att))
        att = _upsample_like(att, fea)
        att = self.sAtt_5(att)
        att_add = self.sAtt_add_2(self.lrelu(self.sAtt_add_1(att)))
        att = torch.sigmoid(att)

        # multiplicative attention in (0, 2) plus an additive term
        fea = fea * att * 2 + att_add
        return fea


class EDVR(nn.Module):
    '''EDVR video restoration network.

    Pipeline: per-frame feature extraction (optionally through
    Predeblur_ResNet_Pyramid) -> 3-level feature pyramid -> PCD alignment of
    every frame to the center frame -> TSA fusion (or a 1x1 conv when w_TSA
    is False) -> residual reconstruction trunk -> two 2x pixel-shuffle
    upsamplings -> addition of the (bilinearly upsampled) center frame.

    Args:
        nf (int): number of feature channels.
        nframes (int): number of input frames N.
        groups (int): deformable groups of the DCN layers in PCD_Align.
        front_RBs (int): residual blocks in the feature extractor.
        back_RBs (int): residual blocks in the reconstruction trunk.
        center (int or None): reference frame index; defaults to nframes // 2.
        predeblur (bool): use Predeblur_ResNet_Pyramid for feature extraction.
        HR_in (bool): inputs are high spatial size; features are extracted at
            1/4 resolution and the output keeps the input size.
        w_TSA (bool): fuse with TSA_Fusion instead of a plain 1x1 convolution.
    '''

    def __init__(self, nf=64, nframes=5, groups=8, front_RBs=5, back_RBs=10, center=None,
                 predeblur=False, HR_in=False, w_TSA=True):
        super(EDVR, self).__init__()
        self.nf = nf
        self.center = nframes // 2 if center is None else center
        self.is_predeblur = bool(predeblur)
        self.HR_in = bool(HR_in)
        self.w_TSA = w_TSA
        ResidualBlock_noBN_f = functools.partial(arch_util.ResidualBlock_noBN, nf=nf)

        #### extract features (for each frame)
        if self.is_predeblur:
            self.pre_deblur = Predeblur_ResNet_Pyramid(nf=nf, HR_in=self.HR_in)
            self.conv_1x1 = nn.Conv2d(nf, nf, 1, 1, bias=True)
        else:
            if self.HR_in:
                self.conv_first_1 = nn.Conv2d(3, nf, 3, 1, 1, bias=True)
                self.conv_first_2 = nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
                self.conv_first_3 = nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
            else:
                self.conv_first = nn.Conv2d(3, nf, 3, 1, 1, bias=True)
        self.feature_extraction = arch_util.make_layer(ResidualBlock_noBN_f, front_RBs)
        self.fea_L2_conv1 = nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
        self.fea_L2_conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.fea_L3_conv1 = nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
        self.fea_L3_conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)

        self.pcd_align = PCD_Align(nf=nf, groups=groups)
        if self.w_TSA:
            self.tsa_fusion = TSA_Fusion(nf=nf, nframes=nframes, center=self.center)
        else:
            self.tsa_fusion = nn.Conv2d(nframes * nf, nf, 1, 1, bias=True)

        #### reconstruction
        self.recon_trunk = arch_util.make_layer(ResidualBlock_noBN_f, back_RBs)
        #### upsampling (the last stages use a fixed width of 64 channels regardless of nf)
        self.upconv1 = nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True)
        self.upconv2 = nn.Conv2d(nf, 64 * 4, 3, 1, 1, bias=True)
        self.pixel_shuffle = nn.PixelShuffle(2)
        self.HRconv = nn.Conv2d(64, 64, 3, 1, 1, bias=True)
        self.conv_last = nn.Conv2d(64, 3, 3, 1, 1, bias=True)

        #### activation function
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

    def forward(self, x):
        '''
        Args:
            x (Tensor): [B, N, C, H, W] input clip with C = 3.
        Returns:
            Tensor: [B, 3, 4H, 4W], or [B, 3, H, W] when HR_in is True.
        Raises:
            ValueError: if HR_in is True and H or W is not divisible by 4
                (the output could not match the residual base otherwise).
        '''
        B, N, C, H, W = x.size()  # N video frames
        if self.HR_in and (H % 4 or W % 4):
            raise ValueError('HR_in=True requires input height and width divisible by 4, '
                             'got {}x{}'.format(H, W))
        x_center = x[:, self.center, :, :, :].contiguous()
        # fold the frame axis into the batch axis for per-frame processing
        x_frames = x.reshape(-1, C, H, W)

        #### extract LR features
        # L1
        if self.is_predeblur:
            L1_fea = self.pre_deblur(x_frames)
            L1_fea = self.conv_1x1(L1_fea)
        else:
            if self.HR_in:
                L1_fea = self.lrelu(self.conv_first_1(x_frames))
                L1_fea = self.lrelu(self.conv_first_2(L1_fea))
                L1_fea = self.lrelu(self.conv_first_3(L1_fea))
            else:
                L1_fea = self.lrelu(self.conv_first(x_frames))
        L1_fea = self.feature_extraction(L1_fea)
        # L2
        L2_fea = self.lrelu(self.fea_L2_conv1(L1_fea))
        L2_fea = self.lrelu(self.fea_L2_conv2(L2_fea))
        # L3
        L3_fea = self.lrelu(self.fea_L3_conv1(L2_fea))
        L3_fea = self.lrelu(self.fea_L3_conv2(L3_fea))

        # read level sizes from the tensors (stride-2 convs give ceil(size / 2))
        H, W = L1_fea.shape[-2:]
        L1_fea = L1_fea.view(B, N, -1, H, W)
        L2_fea = L2_fea.view(B, N, -1, *L2_fea.shape[-2:])
        L3_fea = L3_fea.view(B, N, -1, *L3_fea.shape[-2:])

        #### pcd align
        pyramid = (L1_fea, L2_fea, L3_fea)
        # NOTE(review): the .clone() copies are kept from the original; they
        # presumably give the DCN op standalone (contiguous) tensors — confirm
        # before removing.
        ref_fea_l = [level[:, self.center, :, :, :].clone() for level in pyramid]
        aligned_fea = []
        for i in range(N):
            nbr_fea_l = [level[:, i, :, :, :].clone() for level in pyramid]
            aligned_fea.append(self.pcd_align(nbr_fea_l, ref_fea_l))
        aligned_fea = torch.stack(aligned_fea, dim=1)  # [B, N, C, H, W]

        if not self.w_TSA:
            aligned_fea = aligned_fea.view(B, -1, H, W)
        fea = self.tsa_fusion(aligned_fea)

        out = self.recon_trunk(fea)
        out = self.lrelu(self.pixel_shuffle(self.upconv1(out)))
        out = self.lrelu(self.pixel_shuffle(self.upconv2(out)))
        out = self.lrelu(self.HRconv(out))
        out = self.conv_last(out)
        if self.HR_in:
            base = x_center
        else:
            base = F.interpolate(x_center, scale_factor=4, mode='bilinear', align_corners=False)
        out += base
        return out