"""PyTorch implementation of TOFlow.

Paper: Xue et al., Video Enhancement with Task-Oriented Flow, IJCV 2018
Code reference:
1. https://github.com/anchen1011/toflow
2. https://github.com/Coldog2333/pytoflow
"""

import torch
import torch.nn as nn

from .arch_util import flow_warp

# Per-channel statistics shared by normalize() and denormalize().
_CHANNEL_MEAN = (0.485, 0.456, 0.406)
_CHANNEL_STD = (0.229, 0.224, 0.225)


def _channel_stats(x):
    """Return (mean, std) as (1, 3, 1, 1) tensors cast to the type of ``x``."""
    mean = torch.Tensor(list(_CHANNEL_MEAN)).view(1, 3, 1, 1).type_as(x)
    std = torch.Tensor(list(_CHANNEL_STD)).view(1, 3, 1, 1).type_as(x)
    return mean, std


def normalize(x):
    """Subtract the per-channel mean and divide by the per-channel std.

    input: x: 3-channel images - (B, 3, H, W)
    output: normalized images, same shape as x
    """
    mean, std = _channel_stats(x)
    return (x - mean) / std


def denormalize(x):
    """Invert :func:`normalize`: multiply by the std and add the mean back.

    input: x: normalized 3-channel images - (B, 3, H, W)
    output: de-normalized images, same shape as x
    """
    mean, std = _channel_stats(x)
    return x * std + mean


class SpyNet_Block(nn.Module):
    """A submodule of SpyNet.

    Five 7x7 convolutions (stride 1, padding 3, so the spatial size is kept)
    mapping 8 -> 32 -> 64 -> 32 -> 16 -> 2 channels. Every convolution except
    the last is followed by BatchNorm and ReLU.
    """

    def __init__(self):
        super(SpyNet_Block, self).__init__()

        self.block = nn.Sequential(
            nn.Conv2d(in_channels=8, out_channels=32, kernel_size=7, stride=1, padding=3),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=7, stride=1, padding=3),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=64, out_channels=32, kernel_size=7, stride=1, padding=3),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=32, out_channels=16, kernel_size=7, stride=1, padding=3),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=16, out_channels=2, kernel_size=7, stride=1, padding=3))

    def forward(self, x):
        """
        input: x: [ref im, nbr im, initial flow] - (B, 8, H, W)
        output: estimated flow - (B, 2, H, W)
        """
        return self.block(x)


class SpyNet(nn.Module):
    """SpyNet for estimating optical flow.

    Ranjan et al., Optical Flow Estimation using a Spatial Pyramid Network, 2016
    """

    def __init__(self):
        super(SpyNet, self).__init__()

        # One block per pyramid level, ordered from coarsest to finest.
        self.blocks = nn.ModuleList([SpyNet_Block() for _ in range(4)])

    def forward(self, ref, nbr):
        """Estimate optical flow at the coarse level, upsample, and refine at
        each finer level.

        input: ref: reference image - [B, 3, H, W]
               nbr: the neighboring image to be warped - [B, 3, H, W]
        output: estimated optical flow - [B, 2, H, W]

        H and W are expected to be divisible by 16: the initial flow is
        (H // 16, W // 16) and, once upsampled by 2, has to match the coarsest
        pyramid level (the input average-pooled by 2 three times).
        """
        B, _, H, W = ref.size()

        # Image pyramids, coarsest level first (index 0), full resolution last.
        ref_pyramid = [ref]
        nbr_pyramid = [nbr]
        for _ in range(3):
            ref_pyramid.insert(
                0,
                nn.functional.avg_pool2d(input=ref_pyramid[0], kernel_size=2, stride=2,
                                         count_include_pad=False))
            nbr_pyramid.insert(
                0,
                nn.functional.avg_pool2d(input=nbr_pyramid[0], kernel_size=2, stride=2,
                                         count_include_pad=False))

        # new_zeros allocates directly with the dtype/device of the input
        # instead of creating a CPU tensor and converting it on every call.
        flow = ref_pyramid[0].new_zeros(B, 2, H // 16, W // 16)

        for block, ref_lvl, nbr_lvl in zip(self.blocks, ref_pyramid, nbr_pyramid):
            # Doubling the resolution also doubles displacements measured in
            # pixels, hence the multiplication by 2.
            flow_up = nn.functional.interpolate(input=flow, scale_factor=2, mode='bilinear',
                                                align_corners=True) * 2.0
            # flow_warp is given the flow channels-last: (B, H, W, 2).
            nbr_warped = flow_warp(nbr_lvl, flow_up.permute(0, 2, 3, 1))
            # The block output is added to the upsampled flow (residual update).
            flow = flow_up + block(torch.cat([ref_lvl, nbr_warped, flow_up], 1))
        return flow


class TOFlow(nn.Module):
    """TOFlow network operating on a window of 7 frames.

    Each neighboring frame is warped with a SpyNet-estimated flow, the 7
    frames are stacked along the channel axis and passed through four
    convolutions; the (normalized) reference frame is added to the result.
    """

    def __init__(self, adapt_official=False):
        super(TOFlow, self).__init__()

        self.SpyNet = SpyNet()

        self.conv_3x7_64_9x9 = nn.Conv2d(3 * 7, 64, 9, 1, 4)
        self.conv_64_64_9x9 = nn.Conv2d(64, 64, 9, 1, 4)
        self.conv_64_64_1x1 = nn.Conv2d(64, 64, 1)
        self.conv_64_3_1x1 = nn.Conv2d(64, 3, 1)

        self.relu = nn.ReLU(inplace=True)

        # True if using translated official weights else False
        self.adapt_official = adapt_official

    def forward(self, x):
        """
        input: x: input frames - [B, 7, 3, H, W]
        output: SR reference frame - [B, 3, H, W]
        """
        B, T, C, H, W = x.size()

        # reshape (not view) so that non-contiguous inputs are accepted too;
        # for contiguous inputs it is the same zero-copy view as before.
        x = normalize(x.reshape(-1, C, H, W)).view(B, T, C, H, W)

        ref_idx = 3
        x_ref = x[:, ref_idx, :, :, :]

        # In the official torch code, the 0-th frame is the reference frame
        if self.adapt_official:
            x = x[:, [3, 0, 1, 2, 4, 5, 6], :, :, :]
            ref_idx = 0

        x_warped = []
        for i in range(7):
            if i == ref_idx:
                # The reference frame is used as-is.
                x_warped.append(x_ref)
            else:
                x_nbr = x[:, i, :, :, :]
                flow = self.SpyNet(x_ref, x_nbr).permute(0, 2, 3, 1)
                x_warped.append(flow_warp(x_nbr, flow))
        x_warped = torch.stack(x_warped, dim=1)

        # Merge the frame and channel axes: (B, 7, 3, H, W) -> (B, 21, H, W).
        x = x_warped.view(B, -1, H, W)
        x = self.relu(self.conv_3x7_64_9x9(x))
        x = self.relu(self.conv_64_64_9x9(x))
        x = self.relu(self.conv_64_64_1x1(x))
        # The convolutions predict a residual on top of the reference frame.
        x = self.conv_64_3_1x1(x) + x_ref

        return denormalize(x)