forked from mrq/DL-Art-School
51 lines
1.8 KiB
Python
51 lines
1.8 KiB
Python
import torch
|
|
from models.flownet2 import flow2img
|
|
from models.flownet2.networks import Resample2d
|
|
from torch.cuda.amp import autocast
|
|
|
|
from dlas.trainer.inject import Injector
|
|
|
|
|
|
def create_stereoscopic_injector(opt, env):
    """Build the stereoscopic injector selected by ``opt['type']``.

    Args:
        opt: Injector options; subscripted with ``'type'`` here and passed
            on unchanged to the injector constructor.
        env: Environment object passed on unchanged to the injector
            constructor.

    Returns:
        A ``ResampleInjector`` for ``'stereoscopic_resample'``, a
        ``Flow2Image`` for ``'stereoscopic_flow2image'``, or ``None`` when
        the type is not one handled by this module.
    """
    # Named ``injector_type`` so the builtin ``type`` is not shadowed.
    injector_type = opt['type']
    if injector_type == 'stereoscopic_resample':
        return ResampleInjector(opt, env)
    if injector_type == 'stereoscopic_flow2image':
        return Flow2Image(opt, env)
    # Unrecognised types yield None instead of raising.
    return None
|
|
|
|
|
|
class ResampleInjector(Injector):
    """Injector that runs ``Resample2d`` on an input tensor and a flow field.

    ``opt['flowfield']`` names the state entry holding the flow field. The
    ``self.input`` / ``self.output`` keys used in ``forward`` are not set in
    this class; presumably the ``Injector`` base class provides them.
    """

    def __init__(self, opt, env):
        super().__init__(opt, env)
        self.resample = Resample2d()
        # State key under which the flow field is looked up in forward().
        self.flow = opt['flowfield']

    def forward(self, state):
        """Resample ``state[self.input]`` using ``state[self.flow]``.

        Returns a single-entry dict mapping ``self.output`` to the result.
        """
        source = state[self.input]
        flowfield = state[self.flow]
        # Autocast is explicitly disabled around the resample op.
        # NOTE(review): inputs are passed through without a dtype cast;
        # confirm callers never hand this half-precision tensors.
        with autocast(enabled=False):
            resampled = self.resample(source, flowfield)
        return {self.output: resampled}
|
|
|
|
|
|
# Converts a flowfield to an image representation for viewing purposes.
|
|
# Uses flownet's implementation to do so. Which really sucks. TODO: just do my own implementation in the future.
|
|
# Note: this is not differentiable and is only usable for debugging purposes.
|
|
class Flow2Image(Injector):
    """Debug-only injector that renders a flow field as a viewable image.

    Rendering is delegated to flownet's ``flow2img``, which operates on numpy
    arrays, so the result is produced on the CPU and is not differentiable.
    The ``self.input`` / ``self.output`` keys are not set in this class;
    presumably the ``Injector`` base class provides them.
    """

    def __init__(self, opt, env):
        super().__init__(opt, env)

    def forward(self, state):
        """Convert ``state[self.input]`` into a batch of images.

        The input must be a 4-D tensor; it is treated as channels-first and
        permuted to channels-last before being handed to ``flow2img`` one
        sample at a time.

        Returns:
            A single-entry dict mapping ``self.output`` to a channels-first
            CPU tensor of the rendered images, divided by 255, in the input's
            dtype.
        """
        with torch.no_grad():
            flo = state[self.input].cpu()
            # Unpacking also rejects inputs that are not 4-D.
            bs, _c, _h, _w = flo.shape
            # flow2img works in numpy space, on channels-last data.
            flo_hwc = flo.permute(0, 2, 3, 1)
            flo_np = flo_hwc.numpy()
            # flow2img returns the image in an integer format (presumably
            # uint8 in [0, 255] - TODO confirm), hence the division by 255.
            frames = [
                torch.tensor(flow2img(flo_np[b]), dtype=torch.float) / 255
                for b in range(bs)
            ]
            if frames:
                # Stack the rendered frames rather than writing them into a
                # buffer shaped like the flow field: the rendered image's
                # channel count need not match the flow's, and a mismatch
                # made the per-sample assignment fail.
                imgs = torch.stack(frames).to(flo.dtype)
            else:
                # Empty batch: nothing was rendered, keep the input's shape.
                imgs = torch.empty_like(flo_hwc)
            imgs = imgs.permute(0, 3, 1, 2)
            return {self.output: imgs}
|