import torch
import torch.nn.functional as F
from torch import nn
from transformers import GPT2Config, GPT2Model

from models.lucidrains.x_transformers import Encoder
from trainer.networks import register_model
from utils.util import opt_get, print_network


class ConditioningEncoder(nn.Module):
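    """
    Encodes a conditioning signal of shape (batch, cond_dim, time) into a single
    (batch, embedding_dim) vector: a 1x1 conv projection, a stack of attention blocks,
    then mean-pooling over the time axis.
    """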
    def __init__(self,
                 cond_dim,
                 embedding_dim,
                 attn_blocks=6,
                 num_attn_heads=8,
                 dropout=.1,
                 do_checkpointing=False):
        super().__init__()
        self.init = nn.Conv1d(cond_dim, embedding_dim, kernel_size=1)
        self.attn = Encoder(
                dim=embedding_dim,
                depth=attn_blocks,
                heads=num_attn_heads,
                ff_dropout=dropout,
                attn_dropout=dropout,
                use_rmsnorm=True,
                ff_glu=True,
                rotary_pos_emb=True,
                zero_init_branch_output=True,
                ff_mult=2,
            )
        self.dim = embedding_dim
        self.do_checkpointing = do_checkpointing

    def forward(self, x):
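        # x: (batch, cond_dim, time). The attention Encoder expects (batch, time, dim),
        # so permute around the attention stack, then mean-pool over time.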
        h = self.init(x).permute(0,2,1)
        h = self.attn(h).permute(0,2,1)
        return h.mean(-1)


class ConditioningAR(nn.Module):
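    """
    Autoregressive GPT-2 model over discrete "cheater" codes. The conditioning signal is
    compressed to a single embedding that is prepended to the code sequence; during training,
    a fraction of batches swap that embedding for a learned condition-free embedding so the
    model can also be used without conditioning.
    """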
    def __init__(self, dim, layers, dropout=0, num_vectors=8192, cond_free_percent=.15, fp16=False):
        super().__init__()
        self.cond_encoder = ConditioningEncoder(256, dim)
        self.cond_free_emb = nn.Parameter(torch.randn(1,dim))
        self.unconditioned_percentage = cond_free_percent
        self.fp16 = fp16

        self.config = GPT2Config(vocab_size=1, n_positions=8192, n_embd=dim, n_layer=layers, n_head=dim//64,
                                 n_inner=dim*2, attn_pdrop=dropout, resid_pdrop=dropout, gradient_checkpointing=True,
                                 use_cache=False)
        self.gpt = GPT2Model(self.config)
        del self.gpt.wte  # Unused, we'll do our own embeddings.

        self.embeddings = nn.Embedding(num_vectors, dim)
        self.head = nn.Linear(dim, num_vectors)

    def forward(self, cheater_codes, conditioning, code_lengths=None, return_latent=False):
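        # cheater_codes: (batch, seq) long tensor of discrete code indices.
        # conditioning: (batch, 256, time) conditioning signal fed to the ConditioningEncoder.
        # code_lengths: optional (batch,) valid lengths used to mask the loss.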
        unused_params = []

        cond = self.cond_encoder(conditioning)
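        # Conditioning dropout: with probability `unconditioned_percentage`, replace a batch
        # element's conditioning embedding with the learned condition-free embedding.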
        if self.training and self.unconditioned_percentage > 0:
            unconditioned_batches = torch.rand((cond.shape[0], 1), device=cond.device) < self.unconditioned_percentage
            cond = torch.where(unconditioned_batches, self.cond_free_emb.repeat(cond.shape[0], 1), cond)
        else:
            # Not used in this pass; fold into the loss below with zero weight so DDP does not flag it.
            unused_params.append(self.cond_free_emb)

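        # Embed the codes and prepend the conditioning embedding as the first sequence element.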
        h = self.embeddings(cheater_codes)
        h = torch.cat([cond.unsqueeze(1), h], dim=1)
        targets = cheater_codes  # Prepending the conditioning embedding shifts the inputs by one, so position i predicts code i.

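        # Run the GPT-2 trunk, optionally under fp16 autocast.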
        with torch.autocast(cheater_codes.device.type, enabled=self.fp16):
            h = self.gpt(inputs_embeds=h, return_dict=True).last_hidden_state

            if return_latent:
                return h.float()

            logits = self.head(h[:,:-1]).permute(0,2,1)
            loss = F.cross_entropy(logits, targets, reduction="none")

        # Zero out the loss beyond each sequence's valid length (masked positions still count
        # in the mean's denominator below).
        if code_lengths is not None:
            mask = torch.arange(0, loss.shape[1], device=h.device).unsqueeze(0).repeat(loss.shape[0], 1) < code_lengths.unsqueeze(1)
            loss = loss * mask
        loss = loss.mean()

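        # Tie parameters that were skipped this pass into the loss with zero weight so
        # distributed training still sees a gradient for them.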
        unused_adder = 0
        for p in unused_params:
            unused_adder = unused_adder + p.mean() * 0
        loss = loss + unused_adder

        return loss

    def get_grad_norm_parameter_groups(self):
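        """Parameter groups reported separately when logging gradient norms."""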
        groups = {
            'gpt': list(self.gpt.parameters()),
            'head': list(self.head.parameters()),
            'embeddings': list(self.embeddings.parameters()),
            'conditioning_encoder': list(self.cond_encoder.parameters()),
        }
        return groups


@register_model
def register_cheater_gen_ar(opt_net, opt):
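    """Trainer registry hook: builds ConditioningAR from the network options."""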
    return ConditioningAR(**opt_get(opt_net, ['kwargs'], {}))


def test_ar():
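    # Smoke test: build a small model and run a forward pass on random inputs.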
    model = ConditioningAR(512, 8, cond_free_percent=.5)
    print_network(model)

    codes = torch.randint(0,8192, (2,400))
    cond = torch.randn(2,256,400)
    cl = torch.tensor([200,10])
    codes[1,10:] = 2
    model(codes, cond, cl)
    pg = model.get_grad_norm_parameter_groups()



if __name__ == '__main__':
    test_ar()