forked from mrq/DL-Art-School
99 lines
4.6 KiB
Python
99 lines
4.6 KiB
Python
import torch
|
|
from torch import nn
|
|
|
|
from models.archs.arch_util import ConvGnLelu, ExpansionBlock
|
|
from models.flownet2.networks.resample2d_package.resample2d import Resample2d
|
|
from utils.util import checkpoint
|
|
import torch.nn.functional as F
|
|
|
|
|
|
class Pyramid(nn.Module):
|
|
def __init__(self, nf, depth, processing_convs_per_layer, processing_at_point, scale_per_level=2, block=ConvGnLelu,
|
|
norm=True, return_outlevels=False):
|
|
super(Pyramid, self).__init__()
|
|
levels = []
|
|
current_filters = nf
|
|
self.return_outlevels = return_outlevels
|
|
for d in range(depth):
|
|
level = [block(current_filters, int(current_filters*scale_per_level), kernel_size=3, stride=2, activation=True, norm=False, bias=False)]
|
|
current_filters = int(current_filters*scale_per_level)
|
|
for pc in range(processing_convs_per_layer):
|
|
level.append(block(current_filters, current_filters, kernel_size=3, activation=True, norm=norm, bias=False))
|
|
levels.append(nn.Sequential(*level))
|
|
self.downsamples = nn.ModuleList(levels)
|
|
if processing_at_point > 0:
|
|
point_processor = []
|
|
for p in range(processing_at_point):
|
|
point_processor.append(block(current_filters, current_filters, kernel_size=3, activation=True, norm=norm, bias=False))
|
|
self.point_processor = nn.Sequential(*point_processor)
|
|
else:
|
|
self.point_processor = None
|
|
levels = []
|
|
for d in range(depth):
|
|
level = [ExpansionBlock(current_filters, int(current_filters / scale_per_level), block=block)]
|
|
current_filters = int(current_filters / scale_per_level)
|
|
for pc in range(processing_convs_per_layer):
|
|
level.append(block(current_filters, current_filters, kernel_size=3, activation=True, norm=norm, bias=False))
|
|
levels.append(nn.ModuleList(level))
|
|
self.upsamples = nn.ModuleList(levels)
|
|
|
|
def forward(self, x):
|
|
passthroughs = []
|
|
fea = x
|
|
for lvl in self.downsamples:
|
|
passthroughs.append(fea)
|
|
fea = lvl(fea)
|
|
out_levels = []
|
|
fea = self.point_processor(fea)
|
|
for i, lvl in enumerate(self.upsamples):
|
|
out_levels.append(fea)
|
|
for j, sublvl in enumerate(lvl):
|
|
if j == 0:
|
|
fea = sublvl(fea, passthroughs[-1-i])
|
|
else:
|
|
fea = sublvl(fea)
|
|
|
|
out_levels.append(fea)
|
|
|
|
if self.return_outlevels:
|
|
return tuple(out_levels)
|
|
else:
|
|
return fea
|
|
|
|
|
|
class BasicResamplingFlowNet(nn.Module):
|
|
def create_termini(self, filters):
|
|
return nn.Sequential(ConvGnLelu(int(filters), 2, kernel_size=3, activation=False, norm=False, bias=True),
|
|
nn.Tanh())
|
|
|
|
def __init__(self, nf, resample_scale=1):
|
|
super(BasicResamplingFlowNet, self).__init__()
|
|
self.initial_conv = ConvGnLelu(6, nf, kernel_size=7, activation=False, norm=False, bias=True)
|
|
self.pyramid = Pyramid(nf, 3, 0, 1, 1.5, return_outlevels=True)
|
|
self.termini = nn.ModuleList([self.create_termini(nf*1.5**3),
|
|
self.create_termini(nf*1.5**2),
|
|
self.create_termini(nf*1.5)])
|
|
self.terminus = nn.Sequential(ConvGnLelu(nf, nf, kernel_size=3, activation=True, norm=True, bias=True),
|
|
ConvGnLelu(nf, nf, kernel_size=3, activation=True, norm=True, bias=False),
|
|
ConvGnLelu(nf, nf//2, kernel_size=3, activation=False, norm=False, bias=True),
|
|
ConvGnLelu(nf//2, 2, kernel_size=3, activation=False, norm=False, bias=True),
|
|
nn.Tanh())
|
|
self.scale = resample_scale
|
|
self.resampler = Resample2d()
|
|
|
|
def forward(self, left, right):
|
|
fea = self.initial_conv(torch.cat([left, right], dim=1))
|
|
levels = checkpoint(self.pyramid, fea)
|
|
flos = []
|
|
compares = []
|
|
for i, level in enumerate(levels):
|
|
if i == 3:
|
|
flow = checkpoint(self.terminus, level) * self.scale
|
|
else:
|
|
flow = self.termini[i](level) * self.scale
|
|
img_scale = 1/2**(3-i)
|
|
flos.append(self.resampler(F.interpolate(left, scale_factor=img_scale, mode="area").float(), flow.float()))
|
|
compares.append(F.interpolate(right, scale_factor=img_scale, mode="area"))
|
|
flos_structural_var = torch.var(flos[-1], dim=[-1,-2])
|
|
return flos, compares, flos_structural_var
|