# DL-Art-School/codes/models/steps/stereoscopic.py

import torch
from torch.cuda.amp import autocast
from models.flownet2.networks.resample2d_package.resample2d import Resample2d
from models.steps.injectors import Injector
def create_stereoscopic_injector(opt, env):
    """Build the stereoscopic injector selected by ``opt['type']``.

    Args:
        opt: Injector configuration mapping. The ``'type'`` entry is required
            and selects which injector class is instantiated.
        env: Passed through unchanged to the injector constructor.

    Returns:
        A ``ResampleInjector`` when ``opt['type']`` equals
        ``'stereoscopic_resample'``; ``None`` for any other value.

    Raises:
        KeyError: If ``opt`` is a dict with no ``'type'`` entry.
    """
    # Named ``injector_type`` so the builtin ``type`` is not shadowed.
    injector_type = opt['type']
    if injector_type == 'stereoscopic_resample':
        return ResampleInjector(opt, env)
    return None
class ResampleInjector(Injector):
    """Injector that applies ``Resample2d`` to a state entry using a flow field.

    The flow field is looked up in the state dict under the key configured as
    ``opt['flowfield']``.

    NOTE(review): ``self.input`` and ``self.output`` are not assigned in this
    class; presumably the ``Injector`` base constructor sets them from ``opt``
    -- confirm against ``models.steps.injectors``.
    """

    def __init__(self, opt, env):
        """Store the flow-field state key and create the resampling op.

        Args:
            opt: Injector configuration; must contain ``'flowfield'``, the
                state key under which the flow field is stored.
            env: Forwarded to the ``Injector`` base constructor.
        """
        super().__init__(opt, env)
        self.flow = opt['flowfield']
        self.resample = Resample2d()

    def forward(self, state):
        """Resample ``state[self.input]`` with the flow in ``state[self.flow]``.

        Args:
            state: Mapping that holds both the input and the flow field.

        Returns:
            A one-entry dict mapping ``self.output`` to the resampled result.
        """
        # Autocast is explicitly turned off around the resample call.
        # NOTE(review): presumably because Resample2d needs full precision;
        # confirm against the resample2d package.
        with autocast(enabled=False):
            source = state[self.input]
            flowfield = state[self.flow]
            resampled = self.resample(source, flowfield)
        return {self.output: resampled}