# DL-Art-School/codes/trainer/custom_training_components/stereoscopic.py
#
# (Repository viewer metadata: 48 lines, 1.7 KiB, Python)

import torch
from torch.cuda.amp import autocast
from models.archs.flownet2.networks import Resample2d
from models.archs.flownet2 import flow2img
# (Repository viewer timestamp: 2020-12-18 16:18:34 +00:00)
from trainer.injectors import Injector
def create_stereoscopic_injector(opt, env):
    """Build the stereoscopic injector selected by ``opt['type']``.

    Args:
        opt: Injector options; indexed with ``'type'`` here and then passed
            through unchanged to the chosen injector's constructor.
        env: Environment object, forwarded unchanged to the constructor.

    Returns:
        A ``ResampleInjector`` for ``'stereoscopic_resample'``, a
        ``Flow2Image`` for ``'stereoscopic_flow2image'``, or ``None`` when the
        type is not one of the injectors handled by this module.

    Raises:
        KeyError: If ``opt`` is a dict without a ``'type'`` entry.
    """
    # Deliberately not called `type`, so the builtin stays usable in here.
    injector_type = opt['type']
    if injector_type == 'stereoscopic_resample':
        return ResampleInjector(opt, env)
    if injector_type == 'stereoscopic_flow2image':
        return Flow2Image(opt, env)
    return None
class ResampleInjector(Injector):
    """Injector that runs ``Resample2d`` on an input tensor and a flow field.

    ``opt['flowfield']`` names the state entry holding the flow field.
    ``self.input`` and ``self.output`` are not assigned in this class;
    presumably the ``Injector`` base class sets them from ``opt`` -- confirm
    against the base class.
    """

    def __init__(self, opt, env):
        super().__init__(opt, env)
        self.resample = Resample2d()
        # State key under which the flow field is looked up in forward().
        self.flow = opt['flowfield']

    def forward(self, state):
        """Resample ``state[self.input]`` using ``state[self.flow]``.

        Returns:
            A single-entry dict mapping ``self.output`` to the result of
            ``Resample2d``.
        """
        # Autocast is explicitly switched off around the resample call.
        with autocast(enabled=False):
            source = state[self.input]
            flowfield = state[self.flow]
            resampled = self.resample(source, flowfield)
        return {self.output: resampled}
# Converts a flowfield to an image representation for viewing purposes.
# Uses flownet's implementation (flow2img) to do so.
# TODO: replace flow2img with an in-house implementation in the future.
# Note: this is not differentiable and is only usable for debugging purposes.
class Flow2Image(Injector):
    """Injector that renders a batch of flow fields as images via ``flow2img``.

    ``self.input`` and ``self.output`` are not assigned in this class;
    presumably the ``Injector`` base class sets them from ``opt`` -- confirm
    against the base class.
    """

    def __init__(self, opt, env):
        super(Flow2Image, self).__init__(opt, env)

    def forward(self, state):
        """Render ``state[self.input]`` to images scaled by 1/255.

        The input must be a 4-D tensor; it is moved to the CPU and its second
        dimension is moved last before each sample is handed to ``flow2img``.

        Returns:
            A single-entry dict mapping ``self.output`` to a CPU tensor with
            the input's dtype, laid out as (batch, channels, dim2, dim3). The
            channel count is whatever ``flow2img`` produces, so it no longer
            has to match the channel count of the flow field.

        Raises:
            ValueError: If the input tensor is not 4-dimensional.
        """
        with torch.no_grad():
            flow = state[self.input].cpu()
            if flow.dim() != 4:
                raise ValueError(
                    'expected a 4-D flow tensor, got %d dimensions' % flow.dim())
            # flow2img works in numpy space and is fed channels-last samples.
            flow = flow.permute(0, 2, 3, 1)
            # flow2img returns the image in an integer format, hence the / 255.
            frames = [
                torch.tensor(flow2img(sample), dtype=torch.float) / 255
                for sample in flow.numpy()
            ]
            if frames:
                # Stack instead of writing into a buffer shaped like the flow,
                # so the image channel count may differ from the flow's.
                imgs = torch.stack(frames).to(flow.dtype)
            else:
                # Empty batch: nothing was rendered, mirror the input's shape.
                imgs = torch.empty_like(flow)
            return {self.output: imgs.permute(0, 3, 1, 2).contiguous()}