import os

import numpy
import torch
import torch.nn as nn
from matplotlib import pyplot
from torch.utils.tensorboard import SummaryWriter

from data.audio.unsupervised_audio_dataset import load_audio
from models.gpt_voice.gpt_asr_hf import GptAsrHf
from models.tacotron2.text import text_to_sequence
from trainer.injectors.base_injectors import MelSpectrogramInjector

if __name__ == '__main__':
    # Visualize, for each character of a known transcript, how a trained GptAsrHf
    # model's attention is distributed over the mel frames of one audio clip, and log
    # one bar chart per character to TensorBoard (log dir 'results').

    # Hard-coded inputs: sample clip, its transcript, and the trained checkpoint.
    audio_path = 'Z:\\split\\classified\\fine\\books1\\2_dchha03 The Organization of Peace\\00010.wav'
    checkpoint_path = 'X:\\dlas\\experiments\\train_gpt_asr_mass_hf\\models\\31000_gpt_ema.pth'
    actual_text = 'and it doesn\'t take very long.'
    # Fixed length (in samples) the clip is padded to. F.pad with a negative amount
    # crops, so clips longer than this are silently truncated to this length.
    target_samples = 358395

    # NOTE(review): second argument of load_audio is presumably the sample rate (Hz) — confirm.
    audio_data = load_audio(audio_path, 22050).unsqueeze(0)
    audio_data = torch.nn.functional.pad(audio_data, (0, target_samples - audio_data.shape[-1]))
    mel_inj = MelSpectrogramInjector({'in': 'in', 'out': 'out'}, {})
    mel = mel_inj({'in': audio_data})['out'].cuda()
    labels = torch.IntTensor(text_to_sequence(actual_text, ['english_cleaners'])).unsqueeze(0).cuda()

    model = GptAsrHf(layers=12, model_dim=512, max_mel_frames=1400, max_symbols_per_phrase=250, heads=8)
    model.load_state_dict(torch.load(checkpoint_path))
    # nn.Modules start in training mode; switch to eval so any dropout inside the
    # model is disabled and the extracted attention maps are deterministic.
    model = model.cuda().eval()

    with torch.no_grad():
        attentions = model(mel, labels, return_attentions=True)
        # Stack the per-layer attention tensors along a new leading axis, swap the last
        # two axes, then keep the last max_symbols_per_phrase entries of axis 3 and the
        # first max_mel_frames entries of axis 4.
        # NOTE(review): assumes each per-layer tensor is (batch, heads, query, key) — confirm.
        attentions = torch.stack(attentions, dim=0).permute(0, 1, 2, 4, 3)[
            :, :, :, -model.max_symbols_per_phrase:, :model.max_mel_frames
        ]
        # Sum over the stacked-layer axis, then over axis 1 of the result (presumably
        # heads), and drop singleton dims (e.g. the batch of one).
        attentions = attentions.sum(0).sum(1).squeeze()

    xs = [str(i) for i in range(1, model.max_mel_frames + 1)]
    os.makedirs('results', exist_ok=True)
    logger = SummaryWriter('results')
    try:
        # zip stops at whichever runs out first: the transcript characters or the
        # attention rows (one row per label position).
        for e, (char, character_attn) in enumerate(zip(actual_text, attentions)):
            fig = pyplot.figure()
            ax = fig.add_axes([0, 0, 1, 1])
            ax.bar(xs, character_attn.cpu().numpy())
            logger.add_figure(f'{e}_{char}', fig)
    finally:
        # Ensure pending TensorBoard events are flushed to disk before the script exits.
        logger.close()