import torch
from torch.cuda.amp import autocast
from models.flownet2.networks import Resample2d
from models.flownet2 import flow2img
from trainer.inject import Injector


def create_stereoscopic_injector(opt, env):
    """Build the stereoscopic injector selected by ``opt['type']``.

    Args:
        opt: Injector options dict; ``opt['type']`` selects the injector class and the
            whole dict is forwarded to that class's constructor.
        env: Training environment, forwarded unchanged to the injector constructor.

    Returns:
        A ``ResampleInjector`` for ``'stereoscopic_resample'``, a ``Flow2Image`` for
        ``'stereoscopic_flow2image'``, or ``None`` when the type is not handled here.
    """
    # Named `inj_type` rather than `type` so the builtin is not shadowed.
    inj_type = opt['type']
    if inj_type == 'stereoscopic_resample':
        return ResampleInjector(opt, env)
    if inj_type == 'stereoscopic_flow2image':
        return Flow2Image(opt, env)
    return None


class ResampleInjector(Injector):
    """Apply flownet2's ``Resample2d`` to ``state[self.input]`` using the flow field at ``state[opt['flowfield']]``.

    The result is written to ``state[self.output]``. Presumably this warps the input by
    the flow field; the semantics of ``Resample2d`` live in ``models.flownet2``.
    """
    def __init__(self, opt, env):
        super(ResampleInjector, self).__init__(opt, env)
        self.resample = Resample2d()
        # State key under which the flow field is found.
        self.flow = opt['flowfield']

    def forward(self, state):
        # Resample2d runs with autocast disabled. Disabling autocast does NOT convert tensors that an
        # enclosing autocast region already produced in float16, so both operands are cast to float32
        # explicitly (a no-op for tensors that are already float32). This keeps the operands consistent.
        with autocast(enabled=False):
            img = state[self.input].float()
            flow = state[self.flow].float()
            return {self.output: self.resample(img, flow)}


# Converts a flow field to an image representation for viewing purposes.
# Uses flownet's flow2img implementation, which is inefficient: it requires a round-trip through CPU numpy arrays.
# TODO: replace it with a native torch implementation in the future.
# Note: this is not differentiable and is only usable for debugging purposes.
class Flow2Image(Injector):
    """Render the flow field at ``state[self.input]`` as an image tensor for visual debugging.

    The result is stored under ``self.output`` as a float32 CPU tensor in
    (batch, channels, height, width) layout. Each value is flow2img's integer output
    divided by 255. The work runs under ``torch.no_grad()`` and goes through numpy, so it
    is not differentiable.
    """
    def __init__(self, opt, env):
        super(Flow2Image, self).__init__(opt, env)

    def forward(self, state):
        with torch.no_grad():
            # detach() makes .numpy() legal even if the input tracks gradients.
            # copy=True forces a private float32 CPU buffer. Without it, an input that is already
            # a float32 CPU tensor would share memory with the numpy array handed to flow2img, so
            # any in-place writes flow2img makes would leak back into `state`.
            flo = state[self.input].detach().to(device='cpu', dtype=torch.float32, copy=True)
            bs, _, h, w = flo.shape
            if bs == 0:
                # torch.stack rejects an empty list. Assumes flow2img emits 3-channel (RGB) images -- TODO confirm.
                return {self.output: torch.empty(0, 3, h, w)}
            # flow2img works on numpy arrays in channels-last (h, w, c) layout.
            flo = flo.permute(0, 2, 3, 1).numpy()
            # The output is sized from flow2img's own result, not preallocated from the flow tensor.
            # The flow's channel count differs from the rendered image's, so a buffer shaped like the
            # flow cannot hold the per-sample images. flow2img returns integer pixels, hence /255.
            imgs = torch.stack([torch.tensor(flow2img(sample), dtype=torch.float) / 255 for sample in flo])
            return {self.output: imgs.permute(0, 3, 1, 2)}