# Forked from mrq/DL-Art-School.
# (Original file: 22 lines, 692 B, Python.)
import torch
|
||
|
from torch.cuda.amp import autocast
|
||
|
from models.flownet2.networks.resample2d_package.resample2d import Resample2d
|
||
|
from models.steps.injectors import Injector
|
||
|
|
||
|
|
||
|
def create_stereoscopic_injector(opt, env):
    """Build the stereoscopic injector selected by ``opt['type']``.

    Args:
        opt: injector options; ``opt['type']`` selects the injector. The whole
            ``opt`` is forwarded to the injector's constructor.
        env: forwarded unchanged to the injector's constructor.

    Returns:
        A ``ResampleInjector`` when ``opt['type'] == 'stereoscopic_resample'``,
        otherwise ``None`` (so a caller can try other injector factories).

    Raises:
        KeyError: if ``opt`` is a mapping without a ``'type'`` entry.
    """
    # Named injector_type rather than `type` to avoid shadowing the builtin.
    injector_type = opt['type']
    if injector_type == 'stereoscopic_resample':
        return ResampleInjector(opt, env)
    return None
|
||
|
|
||
|
|
||
|
class ResampleInjector(Injector):
    """Injector applying ``Resample2d`` to a state tensor using a flow field from state.

    Config keys read by this class (in addition to whatever ``Injector`` reads):
        flowfield: state key of the tensor passed to ``Resample2d`` as its
            second argument (presumably a flow field used to warp the input --
            confirm against the flownet2 ``Resample2d`` implementation).

    ``self.input`` and ``self.output`` are not assigned here; presumably the
    ``Injector`` base class populates them from ``opt`` -- confirm.
    """

    def __init__(self, opt, env):
        super().__init__(opt, env)
        self.resample = Resample2d()
        # State key under which the flow field is looked up in forward().
        self.flow = opt['flowfield']

    @staticmethod
    def _to_fp32(t):
        """Upcast float16/bfloat16 tensors to float32; return any other tensor unchanged."""
        if t.dtype in (torch.float16, torch.bfloat16):
            return t.float()
        return t

    def forward(self, state):
        """Resample ``state[self.input]`` with ``state[self.flow]``.

        Args:
            state: mapping from state keys to tensors.

        Returns:
            A dict mapping ``self.output`` to the ``Resample2d`` result.
        """
        # Autocast is disabled so Resample2d runs in full precision. Disabling
        # autocast does NOT convert tensors already produced in reduced precision
        # by an autocast region upstream, so upcast them explicitly. Float64 and
        # other dtypes are passed through untouched.
        with autocast(enabled=False):
            source = self._to_fp32(state[self.input])
            flow = self._to_fp32(state[self.flow])
            return {self.output: self.resample(source, flow)}
|