DL-Art-School/codes/models/steps/stereoscopic.py
James Betker 9c3d059ef0 Updates to be able to train flownet2 in ExtensibleTrainer
Only supports basic losses for now, though.
2020-10-24 11:56:39 -06:00

22 lines
692 B
Python

import torch
from torch.cuda.amp import autocast
from models.flownet2.networks.resample2d_package.resample2d import Resample2d
from models.steps.injectors import Injector
def create_stereoscopic_injector(opt, env):
    """Build the stereoscopic injector selected by ``opt['type']``.

    Args:
        opt: Injector options mapping; must contain a ``'type'`` key. It is
            forwarded unchanged to the injector constructor.
        env: Trainer environment, forwarded unchanged to the injector.

    Returns:
        A ``ResampleInjector`` when ``opt['type'] == 'stereoscopic_resample'``,
        otherwise ``None`` (this factory does not recognise the type).

    Raises:
        KeyError: if ``opt`` has no ``'type'`` key.
    """
    # Named ``injector_type`` rather than ``type`` to avoid shadowing the builtin.
    injector_type = opt['type']
    if injector_type == 'stereoscopic_resample':
        return ResampleInjector(opt, env)
    return None
class ResampleInjector(Injector):
    """Injector that applies flownet2's ``Resample2d`` to a state tensor.

    The tensor at ``state[self.input]`` is resampled with the flow field found
    at ``state[opt['flowfield']]``. The result is returned under
    ``self.output``. ``self.input`` and ``self.output`` are presumably
    populated from ``opt`` by the ``Injector`` base class (not shown here);
    TODO confirm.
    """

    def __init__(self, opt, env):
        """Create the resample op and record which state key holds the flow.

        Args:
            opt: Injector options; must contain ``'flowfield'`` (the state key
                of the flow tensor). Also passed to ``Injector.__init__``.
            env: Trainer environment, passed to ``Injector.__init__``.

        Raises:
            KeyError: if ``opt`` has no ``'flowfield'`` key.
        """
        super().__init__(opt, env)
        self.resample = Resample2d()
        # State key under which the flow field is looked up in ``forward``.
        self.flow = opt['flowfield']

    def forward(self, state):
        """Resample ``state[self.input]`` by ``state[self.flow]``.

        Args:
            state: Mapping from state keys to tensors; must contain
                ``self.input`` and ``self.flow``.

        Returns:
            A dict ``{self.output: resampled_tensor}``.
        """
        # autocast(enabled=False) only stops *new* ops from being auto-cast.
        # Tensors produced in an enclosing autocast region may already be
        # float16, so cast them explicitly to run the resample in float32.
        with autocast(enabled=False):
            image = state[self.input].float()
            flow = state[self.flow].float()
            return {self.output: self.resample(image, flow)}