'''PyTorch implementation of TOFlow
Paper: Xue et al., Video Enhancement with Task-Oriented Flow, IJCV 2018
Code reference:
1. https://github.com/anchen1011/toflow
2. https://github.com/Coldog2333/pytoflow
'''

import torch
import torch.nn as nn
from .arch_util import flow_warp


def normalize(x):
    # Channel-wise normalization with the standard ImageNet mean/std statistics.
    mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).type_as(x)
    std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).type_as(x)
    return (x - mean) / std


def denormalize(x):
    # Inverse of normalize(): map the network output back to the original RGB range.
    mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).type_as(x)
    std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).type_as(x)
    return x * std + mean


class SpyNet_Block(nn.Module):
    '''A submodule of SpyNet.'''

    def __init__(self):
        super(SpyNet_Block, self).__init__()

        self.block = nn.Sequential(
            nn.Conv2d(in_channels=8, out_channels=32, kernel_size=7, stride=1, padding=3),
            nn.BatchNorm2d(32), nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=7, stride=1, padding=3),
            nn.BatchNorm2d(64), nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=64, out_channels=32, kernel_size=7, stride=1, padding=3),
            nn.BatchNorm2d(32), nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=32, out_channels=16, kernel_size=7, stride=1, padding=3),
            nn.BatchNorm2d(16), nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=16, out_channels=2, kernel_size=7, stride=1, padding=3))

    def forward(self, x):
        '''
        input: x: [ref im, nbr im, initial flow] - (B, 8, H, W)
        output: estimated flow - (B, 2, H, W)
        '''
        return self.block(x)


class SpyNet(nn.Module):
    '''SpyNet for estimating optical flow
    Ranjan et al., Optical Flow Estimation using a Spatial Pyramid Network, 2016'''

    def __init__(self):
        super(SpyNet, self).__init__()

        self.blocks = nn.ModuleList([SpyNet_Block() for _ in range(4)])

    def forward(self, ref, nbr):
        '''Estimate optical flow at a coarse level, upsample, then refine at finer levels.
        input: ref: reference image - [B, 3, H, W]
               nbr: the neighboring image to be warped - [B, 3, H, W]
        output: estimated optical flow - [B, 2, H, W]
        '''
        B, C, H, W = ref.size()
        ref = [ref]
        nbr = [nbr]

        # Build a 4-level image pyramid, coarsest resolution first.
        for _ in range(3):
            ref.insert(
                0,
                nn.functional.avg_pool2d(input=ref[0], kernel_size=2, stride=2,
                                         count_include_pad=False))
            nbr.insert(
                0,
                nn.functional.avg_pool2d(input=nbr[0], kernel_size=2, stride=2,
                                         count_include_pad=False))

        flow = torch.zeros(B, 2, H // 16, W // 16).type_as(ref[0])

        # At each level: upsample the flow (doubling its magnitude), warp the
        # neighboring frame with it, and predict a residual flow update.
        for i in range(4):
            flow_up = nn.functional.interpolate(input=flow, scale_factor=2, mode='bilinear',
                                                align_corners=True) * 2.0
            flow = flow_up + self.blocks[i](torch.cat(
                [ref[i], flow_warp(nbr[i], flow_up.permute(0, 2, 3, 1)), flow_up], 1))
        return flow
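

# Coarse-to-fine shape note for SpyNet.forward: with 64x64 inputs the pyramid
# holds the frames at 8, 16, 32 and 64 px per side; the flow starts as zeros at
# 4x4 (H // 16, W // 16) and is upsampled (with its values doubled) before each
# refinement step, so the returned flow matches the full input resolution.
# H and W are therefore assumed to be divisible by 16.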


class TOFlow(nn.Module):
    def __init__(self, adapt_official=False):
        super(TOFlow, self).__init__()

        self.SpyNet = SpyNet()

        self.conv_3x7_64_9x9 = nn.Conv2d(3 * 7, 64, 9, 1, 4)
        self.conv_64_64_9x9 = nn.Conv2d(64, 64, 9, 1, 4)
        self.conv_64_64_1x1 = nn.Conv2d(64, 64, 1)
        self.conv_64_3_1x1 = nn.Conv2d(64, 3, 1)

        self.relu = nn.ReLU(inplace=True)

        self.adapt_official = adapt_official  # True if using translated official weights else False

    def forward(self, x):
        """
        input: x: input frames - [B, 7, 3, H, W]
        output: SR reference frame - [B, 3, H, W]
        """

        B, T, C, H, W = x.size()
        x = normalize(x.view(-1, C, H, W)).view(B, T, C, H, W)

        ref_idx = 3
        x_ref = x[:, ref_idx, :, :, :]

        # In the official torch code, the 0-th frame is the reference frame
        if self.adapt_official:
            x = x[:, [3, 0, 1, 2, 4, 5, 6], :, :, :]
            ref_idx = 0

        # Align every neighboring frame to the reference frame: estimate flow
        # with SpyNet, warp the neighbor, and keep the reference frame as-is.
        x_warped = []
        for i in range(7):
            if i == ref_idx:
                x_warped.append(x_ref)
            else:
                x_nbr = x[:, i, :, :, :]
                flow = self.SpyNet(x_ref, x_nbr).permute(0, 2, 3, 1)
                x_warped.append(flow_warp(x_nbr, flow))
        x_warped = torch.stack(x_warped, dim=1)

        # Fuse the aligned stack with the convolutional head.
        x = x_warped.view(B, -1, H, W)
        x = self.relu(self.conv_3x7_64_9x9(x))
        x = self.relu(self.conv_64_64_9x9(x))
        x = self.relu(self.conv_64_64_1x1(x))
        x = self.conv_64_3_1x1(x) + x_ref  # residual connection to the reference frame

        return denormalize(x)
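

# Minimal smoke-test sketch: feed a random 7-frame clip and check the output
# shape. The relative import of flow_warp means this file belongs to a package,
# so run it as a module (python -m ...) rather than as a standalone script.
if __name__ == '__main__':
    model = TOFlow(adapt_official=False).eval()
    clip = torch.rand(1, 7, 3, 64, 64)  # [B, 7 frames, RGB, H, W]; H and W divisible by 16
    with torch.no_grad():
        out = model(clip)
    print(out.shape)  # expected: torch.Size([1, 3, 64, 64])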