9c3d059ef0
Only supports basic losses for now, though.
22 lines
692 B
Python
22 lines
692 B
Python
import torch
|
|
from torch.cuda.amp import autocast
|
|
from models.flownet2.networks.resample2d_package.resample2d import Resample2d
|
|
from models.steps.injectors import Injector
|
|
|
|
|
|
def create_stereoscopic_injector(opt, env):
    """Build a stereoscopic injector from an option dict.

    Args:
        opt: Injector options; ``opt['type']`` selects which injector to build.
        env: Environment object, forwarded unchanged to the injector constructor.

    Returns:
        A ``ResampleInjector`` when ``opt['type'] == 'stereoscopic_resample'``,
        otherwise ``None`` (the type is not handled by this factory).

    Raises:
        KeyError: if ``opt`` has no ``'type'`` entry.
    """
    # Named ``injector_type`` rather than ``type`` to avoid shadowing the builtin.
    injector_type = opt['type']
    if injector_type == 'stereoscopic_resample':
        return ResampleInjector(opt, env)
    return None
|
|
|
|
|
|
class ResampleInjector(Injector):
    """Injector that applies FlowNet2's ``Resample2d`` op to a state tensor.

    ``forward`` resamples ``state[self.input]`` using the flow field stored at
    ``state[opt['flowfield']]`` and publishes the result under ``self.output``.
    ``self.input`` / ``self.output`` are presumably populated from ``opt`` by the
    ``Injector`` base class -- confirm against ``models.steps.injectors``.
    """

    def __init__(self, opt, env):
        super(ResampleInjector, self).__init__(opt, env)
        self.resample = Resample2d()
        # State key under which the flow field is looked up in ``forward``.
        self.flow = opt['flowfield']

    @staticmethod
    def _to_fp32(t):
        """Return *t* upcast to float32 if it is a float16/bfloat16 tensor, else *t*."""
        if t.dtype in (torch.float16, torch.bfloat16):
            return t.float()
        return t

    def forward(self, state):
        """Resample the input tensor by the flow field, outside of autocast.

        Args:
            state: Mapping of state keys to tensors; must contain ``self.input``
                and ``self.flow``.

        Returns:
            A one-entry dict ``{self.output: resampled_tensor}``.
        """
        with autocast(enabled=False):
            # Disabling autocast does not convert tensors that an upstream
            # autocast-enabled region already produced in reduced precision;
            # the PyTorch AMP docs require casting them explicitly so the op
            # really runs in fp32 here.
            img = self._to_fp32(state[self.input])
            flow = self._to_fp32(state[self.flow])
            return {self.output: self.resample(img, flow)}