import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from trainer.inject import Injector
from trainer.networks import register_model
from utils.util import checkpoint


def create_injector(opt, env):
    injector_type = opt['type']  # avoid shadowing the builtin `type`
    if injector_type == 'igpt_resolve':
        return ResolveInjector(opt, env)
    return None


class ResolveInjector(Injector):
    def __init__(self, opt, env):
        super().__init__(opt, env)
        self.gen = opt['generator']
        self.samples = opt['num_samples']
        self.temperature = opt['temperature']

    def forward(self, state):
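        """Sample the bottom half of each image in state[self.input]
        autoregressively from the generator, conditioned on the quantized top
        half. Emits {self.output: completions}, with num_samples completions
        per input image, shaped [n, c, h, w]."""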
        gen = self.env['generators'][self.gen].module
        img = state[self.input]
        b, c, h, w = img.shape
        qimg = gen.quantize(img)  # [seq_len, batch]
        s, b = qimg.shape
        # Keep the top half of each image as the conditioning context.
        qimg = qimg[:s // 2, :]
        # Tile the batch dim so <num_samples> completions are drawn per image.
        output = qimg.repeat(1, self.samples)

        # Placeholder for the next position; the logits emitted there predict
        # the token that should fill it. Sized from output so batch > 1 works.
        pad = torch.zeros(1, output.shape[1], dtype=torch.long, device=img.device)
        with torch.no_grad():
            # Generate the remaining s - s//2 tokens to complete the sequence.
            for _ in range(s - s // 2):
                logits, _ = gen(torch.cat((output, pad), dim=0), already_quantized=True)
                logits = logits[-1, :, :] / self.temperature
                probs = F.softmax(logits, dim=-1)
                pred = torch.multinomial(probs, num_samples=1).transpose(1, 0)
                output = torch.cat((output, pred), dim=0)
        # [seq_len, batch] -> [h, w, batch, c] -> [batch, c, h, w]
        output = gen.unquantize(output.reshape(h, w, -1))
        return {self.output: output.permute(2, 3, 0, 1).contiguous()}


class Block(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.ln_1 = nn.LayerNorm(embed_dim)
        self.ln_2 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * 4),
            nn.GELU(),
            nn.Linear(embed_dim * 4, embed_dim),
        )

    def forward(self, x):
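        # Causal mask: -inf above the diagonal so position i attends only to j <= i.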
        attn_mask = torch.full(
            (len(x), len(x)), -float("Inf"), device=x.device, dtype=x.dtype
        )
        attn_mask = torch.triu(attn_mask, diagonal=1)

        x = self.ln_1(x)
        a, _ = self.attn(x, x, x, attn_mask=attn_mask, need_weights=False)
        x = x + a
        m = self.mlp(self.ln_2(x))
        x = x + m
        return x


class iGPT2(nn.Module):
    def __init__(
        self, embed_dim, num_heads, num_layers, num_positions, num_vocab, centroids_file
    ):
        super().__init__()

        self.centroids = nn.Parameter(
            torch.from_numpy(np.load(centroids_file)), requires_grad=False
        )
        self.embed_dim = embed_dim

        # start of sequence token
        self.sos = nn.Parameter(torch.zeros(embed_dim))
        nn.init.normal_(self.sos)

        self.token_embeddings = nn.Embedding(num_vocab, embed_dim)
        self.position_embeddings = nn.Embedding(num_positions, embed_dim)

        self.layers = nn.ModuleList()
        for _ in range(num_layers):
            self.layers.append(Block(embed_dim, num_heads))

        self.ln_f = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_vocab, bias=False)
        self.clf_head = nn.Linear(embed_dim, 10)  # Hard-coded to 10 classes; unused here since this model is not trained as a classifier.

    def squared_euclidean_distance(self, a, b):
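        """Pairwise squared Euclidean distances between the rows of a [n, c]
        and b [m, c], via ||a - b||^2 = ||a||^2 - 2*a.b + ||b||^2. Returns [n, m]."""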
        b = torch.transpose(b, 0, 1)
        a2 = torch.sum(torch.square(a), dim=1, keepdim=True)
        b2 = torch.sum(torch.square(b), dim=0, keepdim=True)
        ab = torch.matmul(a, b)
        d = a2 - 2 * ab + b2
        return d

    def quantize(self, x):
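        """Map each pixel of x ([b, c, h, w]) to the index of its nearest color
        centroid, then flatten the images into a [seq_len, batch] token sequence."""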
        b, c, h, w = x.shape
        # [B, C, H, W] => [B, H, W, C]
        x = x.permute(0, 2, 3, 1).contiguous()
        x = x.view(-1, c)  # flatten to pixels
        d = self.squared_euclidean_distance(x, self.centroids)
        x = torch.argmin(d, 1)
        x = x.view(b, h, w)

        # Reshape output to [seq_len, batch].
        x = x.view(x.shape[0], -1)  # flatten images into sequences
        x = x.transpose(0, 1).contiguous()  # to shape [seq len, batch]
        return x

    def unquantize(self, x):
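        """Inverse of quantize(): look up the centroid color for each token index."""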
        return self.centroids[x]

    def forward(self, x, already_quantized=False):
        """
        Expect input as shape [b, c, h, w]
        """

        if not already_quantized:
            x = self.quantize(x)
        length, batch = x.shape

        h = self.token_embeddings(x)

        # prepend sos token
        sos = torch.ones(1, batch, self.embed_dim, device=x.device) * self.sos
        h = torch.cat([sos, h[:-1, :, :]], dim=0)

        # add positional embeddings
        positions = torch.arange(length, device=x.device).unsqueeze(-1)
        h = h + self.position_embeddings(positions).expand_as(h)

        # transformer
        for layer in self.layers:
            h = checkpoint(layer, h)

        h = self.ln_f(h)

        logits = self.head(h)

        return logits, x


@register_model
def register_igpt2(opt_net, opt):
    return iGPT2(opt_net['embed_dim'], opt_net['num_heads'], opt_net['num_layers'],
                 opt_net['num_pixels'] ** 2, opt_net['num_vocab'],
                 centroids_file=opt_net['centroids_file'])
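

# Minimal smoke-test sketch (not part of the original file; hypothetical usage).
# It fabricates a random centroid palette instead of loading a real .npy file,
# builds a tiny iGPT2, and runs one forward pass on dummy 8x8 RGB images. It
# assumes this file is run inside the repo (so the trainer/utils imports
# resolve) and that utils.util.checkpoint behaves like a plain forward call.
if __name__ == '__main__':
    import os
    import tempfile

    centroids = np.random.rand(16, 3).astype(np.float32)  # [num_vocab, channels]
    path = os.path.join(tempfile.gettempdir(), 'igpt_test_centroids.npy')
    np.save(path, centroids)

    model = iGPT2(embed_dim=32, num_heads=2, num_layers=2,
                  num_positions=8 ** 2, num_vocab=16, centroids_file=path)
    imgs = torch.rand(4, 3, 8, 8)  # [b, c, h, w]
    logits, tokens = model(imgs)
    print(logits.shape)  # torch.Size([64, 4, 16]) -> [seq_len, batch, num_vocab]
    print(tokens.shape)  # torch.Size([64, 4])     -> [seq_len, batch]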