DL-Art-School/codes/models/archs/pyramid_arch.py

99 lines
4.6 KiB
Python
Raw Normal View History

import torch
from torch import nn
from models.archs.arch_util import ConvGnLelu, UpconvBlock, ExpansionBlock
from models.flownet2.networks.resample2d_package.resample2d import Resample2d
from utils.util import checkpoint
import torch.nn.functional as F
class Pyramid(nn.Module):
def __init__(self, nf, depth, processing_convs_per_layer, processing_at_point, scale_per_level=2, block=ConvGnLelu,
norm=True, return_outlevels=False):
super(Pyramid, self).__init__()
levels = []
current_filters = nf
self.return_outlevels = return_outlevels
for d in range(depth):
level = [block(current_filters, int(current_filters*scale_per_level), kernel_size=3, stride=2, activation=True, norm=False, bias=False)]
current_filters = int(current_filters*scale_per_level)
for pc in range(processing_convs_per_layer):
level.append(block(current_filters, current_filters, kernel_size=3, activation=True, norm=norm, bias=False))
levels.append(nn.Sequential(*level))
self.downsamples = nn.ModuleList(levels)
if processing_at_point > 0:
point_processor = []
for p in range(processing_at_point):
point_processor.append(block(current_filters, current_filters, kernel_size=3, activation=True, norm=norm, bias=False))
self.point_processor = nn.Sequential(*point_processor)
else:
self.point_processor = None
levels = []
for d in range(depth):
level = [ExpansionBlock(current_filters, int(current_filters / scale_per_level), block=block)]
current_filters = int(current_filters / scale_per_level)
for pc in range(processing_convs_per_layer):
level.append(block(current_filters, current_filters, kernel_size=3, activation=True, norm=norm, bias=False))
levels.append(nn.ModuleList(level))
self.upsamples = nn.ModuleList(levels)
def forward(self, x):
passthroughs = []
fea = x
for lvl in self.downsamples:
passthroughs.append(fea)
fea = lvl(fea)
out_levels = []
fea = self.point_processor(fea)
for i, lvl in enumerate(self.upsamples):
out_levels.append(fea)
for j, sublvl in enumerate(lvl):
if j == 0:
fea = sublvl(fea, passthroughs[-1-i])
else:
fea = sublvl(fea)
out_levels.append(fea)
if self.return_outlevels:
return tuple(out_levels)
else:
return fea
class BasicResamplingFlowNet(nn.Module):
def create_termini(self, filters):
return nn.Sequential(ConvGnLelu(int(filters), 2, kernel_size=3, activation=False, norm=False, bias=True),
nn.Tanh())
def __init__(self, nf, resample_scale=1):
super(BasicResamplingFlowNet, self).__init__()
self.initial_conv = ConvGnLelu(6, nf, kernel_size=7, activation=False, norm=False, bias=True)
self.pyramid = Pyramid(nf, 3, 0, 1, 1.5, return_outlevels=True)
self.termini = nn.ModuleList([self.create_termini(nf*1.5**3),
self.create_termini(nf*1.5**2),
self.create_termini(nf*1.5)])
self.terminus = nn.Sequential(ConvGnLelu(nf, nf, kernel_size=3, activation=True, norm=True, bias=True),
ConvGnLelu(nf, nf, kernel_size=3, activation=True, norm=True, bias=False),
ConvGnLelu(nf, nf//2, kernel_size=3, activation=False, norm=False, bias=True),
ConvGnLelu(nf//2, 2, kernel_size=3, activation=False, norm=False, bias=True),
nn.Tanh())
self.scale = resample_scale
self.resampler = Resample2d()
def forward(self, left, right):
fea = self.initial_conv(torch.cat([left, right], dim=1))
levels = checkpoint(self.pyramid, fea)
flos = []
compares = []
for i, level in enumerate(levels):
if i == 3:
flow = checkpoint(self.terminus, level) * self.scale
else:
flow = self.termini[i](level) * self.scale
img_scale = 1/2**(3-i)
flos.append(self.resampler(F.interpolate(left, scale_factor=img_scale, mode="area").float(), flow.float()))
compares.append(F.interpolate(right, scale_factor=img_scale, mode="area"))
flos_structural_var = torch.var(flos[-1], dim=[-1,-2])
return flos, compares, flos_structural_var