import torch
import torch.nn.functional as F
from torch import nn

from models.archs.arch_util import ConvGnLelu, UpconvBlock, ExpansionBlock
from models.flownet2.networks.resample2d_package.resample2d import Resample2d
from utils.util import checkpoint


class Pyramid(nn.Module):
    """Encoder/decoder ("pyramid") with skip connections, built from ``block`` convs.

    Encoder: ``depth`` levels. Each level is a stride-2 ``block`` (norm=False)
    that multiplies the filter count by ``scale_per_level`` (truncated with
    ``int``). It is followed by ``processing_convs_per_layer`` 3x3 ``block``
    convs. At the bottom, ``processing_at_point`` further 3x3 convs are
    applied; when it is 0, none are applied.

    Decoder: ``depth`` levels mirroring the encoder. Each level begins with an
    ``ExpansionBlock`` that takes the current features plus the input of the
    matching encoder level (skip connection). It is followed by
    ``processing_convs_per_layer`` 3x3 ``block`` convs.

    Args:
        nf: filter count of the input tensor.
        depth: number of down/up levels.
        processing_convs_per_layer: extra convs after each down/up step.
        processing_at_point: number of convs applied at the bottom; 0 disables them.
        scale_per_level: filter multiplier per level. Counts are int-truncated,
            so with non-integer scales the decoder may not return exactly to ``nf``.
        block: conv block constructor, called as
            ``block(in_filters, out_filters, kernel_size=..., activation=..., norm=..., bias=...)``.
        norm: ``norm`` flag for the processing convs; the stride-2 convs always use norm=False.
        return_outlevels: if True, ``forward`` returns intermediate decoder
            features as a tuple instead of only the final tensor.
    """

    def __init__(self, nf, depth, processing_convs_per_layer, processing_at_point, scale_per_level=2,
                 block=ConvGnLelu, norm=True, return_outlevels=False):
        super().__init__()
        self.return_outlevels = return_outlevels
        current_filters = nf

        levels = []
        for _ in range(depth):
            next_filters = int(current_filters * scale_per_level)
            level = [block(current_filters, next_filters, kernel_size=3, stride=2, activation=True, norm=False,
                           bias=False)]
            current_filters = next_filters
            for _ in range(processing_convs_per_layer):
                level.append(block(current_filters, current_filters, kernel_size=3, activation=True, norm=norm,
                                   bias=False))
            levels.append(nn.Sequential(*level))
        self.downsamples = nn.ModuleList(levels)

        if processing_at_point > 0:
            self.point_processor = nn.Sequential(*[
                block(current_filters, current_filters, kernel_size=3, activation=True, norm=norm, bias=False)
                for _ in range(processing_at_point)
            ])
        else:
            # No bottom processing; forward() skips this stage when it is None.
            self.point_processor = None

        levels = []
        for _ in range(depth):
            next_filters = int(current_filters / scale_per_level)
            # Element 0 is the ExpansionBlock (called with a skip tensor in forward);
            # the remaining elements are plain single-input convs.
            level = [ExpansionBlock(current_filters, next_filters, block=block)]
            current_filters = next_filters
            for _ in range(processing_convs_per_layer):
                level.append(block(current_filters, current_filters, kernel_size=3, activation=True, norm=norm,
                                   bias=False))
            levels.append(nn.ModuleList(level))
        self.upsamples = nn.ModuleList(levels)

    def forward(self, x):
        """Run the encoder, the optional bottom convs, then the decoder.

        Returns:
            If ``return_outlevels`` is False, the final decoder output.
            Otherwise a tuple of ``depth + 1`` tensors, coarsest first. The
            first is the bottom features (after the point processor, if any),
            followed by the output of each decoder level.
        """
        passthroughs = []
        fea = x
        for lvl in self.downsamples:
            # Record each encoder level's input for use as a decoder skip connection.
            passthroughs.append(fea)
            fea = lvl(fea)

        # Bug fix: point_processor is None when processing_at_point == 0.
        if self.point_processor is not None:
            fea = self.point_processor(fea)

        out_levels = []
        # Decoder level i pairs with encoder level depth-1-i, i.e. passthroughs reversed.
        for lvl, skip in zip(self.upsamples, reversed(passthroughs)):
            out_levels.append(fea)
            fea = lvl[0](fea, skip)
            for sublvl in lvl[1:]:
                fea = sublvl(fea)
        out_levels.append(fea)

        if self.return_outlevels:
            return tuple(out_levels)
        return fea


class BasicResamplingFlowNet(nn.Module):
    """Predict 2-channel flows at four scales and warp ``left`` with each of them.

    Args:
        nf: base filter count for the initial conv and the pyramid.
        resample_scale: multiplier applied to each Tanh-bounded flow before resampling.
    """

    def create_termini(self, filters):
        """Build a head mapping ``filters`` channels to a 2-channel, Tanh-bounded output."""
        return nn.Sequential(ConvGnLelu(int(filters), 2, kernel_size=3, activation=False, norm=False, bias=True),
                             nn.Tanh())

    def __init__(self, nf, resample_scale=1):
        super().__init__()
        # forward() concatenates left and right along dim 1 before this conv,
        # hence 6 input channels (presumably two 3-channel images -- confirm).
        self.initial_conv = ConvGnLelu(6, nf, kernel_size=7, activation=False, norm=False, bias=True)
        # depth=3, no per-level processing convs, one bottom conv, 1.5x filters per level.
        self.pyramid = Pyramid(nf, 3, 0, 1, 1.5, return_outlevels=True)
        # Heads for the three coarsest pyramid outputs, coarsest first.
        # NOTE: Pyramid int-truncates its filter counts at every level. These
        # counts agree with it exactly when nf is a multiple of 8.
        self.termini = nn.ModuleList([self.create_termini(nf * 1.5 ** 3),
                                      self.create_termini(nf * 1.5 ** 2),
                                      self.create_termini(nf * 1.5)])
        # Deeper head for the finest pyramid output (used at scale 1 in forward).
        self.terminus = nn.Sequential(ConvGnLelu(nf, nf, kernel_size=3, activation=True, norm=True, bias=True),
                                      ConvGnLelu(nf, nf, kernel_size=3, activation=True, norm=True, bias=False),
                                      ConvGnLelu(nf, nf // 2, kernel_size=3, activation=False, norm=False, bias=True),
                                      ConvGnLelu(nf // 2, 2, kernel_size=3, activation=False, norm=False, bias=True),
                                      nn.Tanh())
        self.scale = resample_scale
        self.resampler = Resample2d()

    def forward(self, left, right):
        """Warp ``left`` with flows predicted at four pyramid scales.

        Args:
            left, right: image tensors, concatenated along dim 1 as the input to ``initial_conv``.

        Returns:
            flos: list of 4 warped images (not flows), coarsest first. Entry i
                is ``left`` area-downsampled by ``1/2**(3-i)`` and resampled
                with that level's flow.
            compares: ``right`` area-downsampled to the same four scales.
            flos_structural_var: variance of the finest warped image over its last two dims.
        """
        fea = self.initial_conv(torch.cat([left, right], dim=1))
        levels = checkpoint(self.pyramid, fea)
        flos = []
        compares = []
        for i, level in enumerate(levels):
            if i == 3:
                flow = checkpoint(self.terminus, level) * self.scale
            else:
                flow = self.termini[i](level) * self.scale
            img_scale = 1 / 2 ** (3 - i)
            # Resample2d is fed float32 tensors; presumably its kernel does not
            # accept reduced precision -- confirm.
            flos.append(self.resampler(F.interpolate(left, scale_factor=img_scale, mode="area").float(),
                                       flow.float()))
            compares.append(F.interpolate(right, scale_factor=img_scale, mode="area"))
        flos_structural_var = torch.var(flos[-1], dim=[-1, -2])
        return flos, compares, flos_structural_var