import math

import torch
import torch.nn as nn
import torchvision
from tqdm import tqdm
from models.segformer.backbone import backbone50
from trainer.networks import register_model


# A torch.gather() variant that behaves the way one would intuitively expect: it pulls the values at the given
# (row, column) indexes out of the input.
def gather_2d(input, index):
    """Pull one spatial location per batch element out of a 4D tensor.

    Args:
        input: Tensor of shape (b, c, h, w).
        index: Integer tensor of shape (b, 2). `index[:, 0]` is the row and
            `index[:, 1]` is the column to read for each batch element.
            torch.gather requires an int64 index.

    Returns:
        Tensor of shape (b, c) holding, for every batch element, the values
        of all channels at the requested location.
    """
    b, c, h, w = input.shape
    # reshape (rather than view) so non-contiguous inputs are accepted too.
    flat = input.reshape(b, c, h * w)
    # Convert (row, col) into an offset into the flattened h*w axis.
    flat_index = index[:, 0] * w + index[:, 1]
    # Same offset for every channel: (b,) -> (b, c, 1).
    flat_index = flat_index[:, None, None].expand(-1, c, 1)
    gathered = torch.gather(flat, dim=2, index=flat_index)
    # Only drop the gathered axis. A bare squeeze() would also remove the
    # batch or channel axis whenever b == 1 or c == 1.
    return gathered.squeeze(2)


class DilatorModule(nn.Module):
    """Convolve a feature map, then project the features at one location.

    A 3x3 convolution is always applied. When `max_dilation` is greater than
    one it is followed by ReLU, batch norm and a second 3x3 convolution
    dilated by `max_dilation`. The per-sample feature vector at `loc` is then
    read out with `gather_2d` and passed through a linear layer.

    Args:
        input_channels: Channel count of the incoming feature map (kept by
            both convolutions).
        output_channels: Size of the vector produced by the linear layer.
        max_dilation: Dilation (and padding) of the second convolution; the
            second stage is skipped entirely when this is <= 1.
    """

    def __init__(self, input_channels, output_channels, max_dilation):
        super().__init__()
        self.max_dilation = max_dilation
        self.conv1 = nn.Conv2d(
            in_channels=input_channels,
            out_channels=input_channels,
            kernel_size=3,
            padding=1,
            dilation=1,
            bias=True,
        )
        has_dilated_stage = max_dilation > 1
        if has_dilated_stage:
            self.bn = nn.BatchNorm2d(input_channels)
            self.relu = nn.ReLU()
            self.conv2 = nn.Conv2d(
                in_channels=input_channels,
                out_channels=input_channels,
                kernel_size=3,
                padding=max_dilation,
                dilation=max_dilation,
                bias=True,
            )
        self.dense = nn.Linear(input_channels, output_channels, bias=True)

    def forward(self, inp, loc):
        """Return the projected features of `inp` at the locations in `loc`."""
        features = self.conv1(inp)
        if self.max_dilation > 1:
            # Activation first, normalisation second, then the dilated conv.
            activated = self.relu(features)
            normalised = self.bn(activated)
            features = self.conv2(normalised)

        # The convolutions run over the whole map even though only a single
        # location per sample is kept; restricting them to a crop might be
        # cheaper.
        picked = gather_2d(features, loc)
        return self.dense(picked.contiguous())


# Grabbed from torch examples: https://github.com/pytorch/examples/blob/master/word_language_model/model.py#L65
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a sequence-first tensor.

    The encoding table is stored in the non-trainable buffer `pe` of shape
    (max_len, 1, d_model): even feature columns hold sines and odd columns
    hold cosines of the position scaled by geometrically spaced frequencies.

    Args:
        d_model: Size of the feature (last) dimension. Odd sizes are
            supported; the final column then holds a sine with no paired
            cosine.
        max_len: Number of positions in the table. Inputs whose first
            dimension exceeds this cannot be encoded.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd column than even column, so
        # trim the angle table to fit instead of failing on a shape mismatch.
        pe[:, 1::2] = torch.cos(angles[:, :d_model // 2])
        # (max_len, d_model) -> (max_len, 1, d_model) so it broadcasts over
        # the middle dimension of the input.
        self.register_buffer('pe', pe.unsqueeze(1))

    def forward(self, x):
        """Return `x` plus the encodings of its first `x.size(0)` positions.

        `x` must be broadcastable against a (x.size(0), 1, d_model) slice of
        the table.
        """
        return x + self.pe[:x.size(0), :]


# Simple mean() layer encoded into a class so that BYOL can grab it.
class Tail(nn.Module):
    """Reduce a tensor by averaging over its first dimension."""

    def forward(self, x):
        """Return the mean of `x` taken along dim 0."""
        return torch.mean(x, dim=0)


class Segformer(nn.Module):
    """Encode one image location into a latent vector.

    The backbone's feature maps are each passed through several
    `DilatorModule`s, every one of which yields a token for the requested
    location. The tokens are position-encoded, run through a stack of
    transformer encoder layers and averaged by `Tail`.

    Args:
        latent_channels: Size of every token and of the returned vector.
        layers: Number of transformer encoder layers.
    """

    def __init__(self, latent_channels=1024, layers=8):
        super().__init__()
        self.backbone = backbone50()
        # One entry per backbone feature map: its channel count and the
        # dilations applied to it.
        # NOTE(review): assumes backbone50 yields four maps with these channel counts — confirm.
        backbone_channels = [256, 512, 1024, 2048]
        dilations = [[1, 2, 3, 4], [1, 2, 3], [1, 2], [1]]
        self.dilators = nn.ModuleList([
            nn.ModuleList([
                DilatorModule(channels, latent_channels, dilation)
                for dilation in stage_dilations
            ])
            for channels, stage_dilations in zip(backbone_channels, dilations)
        ])

        # Exactly one token is produced per dilator, so that is the longest
        # sequence the position encoder will ever see.
        token_count = sum(len(stage_dilations) for stage_dilations in dilations)
        self.token_position_encoder = PositionalEncoding(latent_channels, max_len=token_count)
        self.transformer_layers = nn.Sequential(
            *[nn.TransformerEncoderLayer(latent_channels, nhead=4) for _ in range(layers)]
        )
        self.tail = Tail()

    def forward(self, img=None, layers=None, pos=None, return_layers=False):
        """Compute the latent for `pos`, or just the backbone feature maps.

        Args:
            img: Image batch fed to the backbone. Takes precedence over
                `layers` when both are given.
            layers: Precomputed backbone feature maps, used when `img` is
                None.
            pos: Integer tensor of shape (batch, 2), or of shape (2,) to use
                one location for the whole batch. Coordinates are consumed by
                `gather_2d`; they are divided by 4 for the first feature map
                and by a further 2 for each subsequent one.
            return_layers: When True, return the feature maps and skip the
                rest of the network (`pos` is then not needed).

        Returns:
            The feature maps when `return_layers` is True, otherwise the
            output of `Tail` applied to the transformer output.

        Raises:
            ValueError: If neither `img` nor `layers` is given, or if `pos`
                is missing when a latent is requested.
        """
        if img is None and layers is None:
            raise ValueError("Either `img` or `layers` must be provided.")
        if img is not None:
            bs = img.shape[0]
            layers = self.backbone(img)
        else:
            bs = layers[0].shape[0]
        if return_layers:
            return layers

        if pos is None:
            raise ValueError("`pos` must be provided unless `return_layers` is True.")

        # A single position can be optionally given, in which case we need to expand it to represent the entire input.
        if pos.shape == (2,):
            pos = pos.unsqueeze(0).repeat(bs, 1)

        tokens = []
        pos = pos // 4
        for layer_out, dilator in zip(layers, self.dilators):
            for subdilator in dilator:
                tokens.append(subdilator(layer_out, pos))
            # Rescale the coordinates for the next feature map.
            pos = pos // 2

        # The torch transformer expects the set dimension to be 0.
        tokens = torch.stack(tokens, dim=0)
        tokens = self.token_position_encoder(tokens)
        tokens = self.transformer_layers(tokens)
        return self.tail(tokens)


@register_model
def register_segformer(opt_net, opt):
    """Factory exposed through `register_model`; builds a default Segformer.

    Neither `opt_net` nor `opt` is consulted: the model is always created
    with its default constructor arguments.
    """
    model = Segformer()
    return model


if __name__ == '__main__':
    # Smoke test: repeatedly push a random batch and random positions through
    # the model and print the output shape.
    # Use the GPU when one is present instead of failing on CPU-only hosts.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = Segformer().to(device)
    for _ in tqdm(range(1000)):
        # Allocate directly on the target device to avoid a host-to-device copy.
        test_tensor = torch.randn(64, 3, 224, 224, device=device)
        test_pos = torch.randint(0, 224, (64, 2), device=device)
        print(model(img=test_tensor, pos=test_pos).shape)