import torch
import torchvision
from torch.nn.functional import interpolate
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm

import trainer.eval.evaluator as evaluator
from data import create_dataset
from models.vqvae.kmeans_mask_producer import UResnetMaskProducer
from utils.util import opt_get


class CategorizationLossEvaluator(evaluator.Evaluator):
    """Evaluate a classifier's cross-entropy loss and top-1 / top-5 accuracy.

    Runs the model over the dataset built from ``opt_eval['dataset']`` and
    reports per-sample averages, so a short final batch is weighted by its
    actual size rather than counted as a full batch.
    """

    def __init__(self, model, opt_eval, env):
        """
        Args:
            model: network under evaluation. Called as ``model(hq)`` or, when
                masking is enabled, ``model(hq, masks)``. May return a single
                logits tensor or a list/tuple of outputs.
            opt_eval: evaluator options. Reads ``batch_size`` (required),
                ``dataset`` (passed to ``create_dataset``), ``gen_index``
                (default 0) and ``masking`` (default False).
            env: evaluation environment; ``env['device']`` is used by
                ``perform_eval`` to place inputs and labels.
        """
        super().__init__(model, opt_eval, env, uses_all_ddp=False)
        self.batch_sz = opt_eval['batch_size']
        assert self.batch_sz is not None
        self.dataset = create_dataset(opt_eval['dataset'])
        self.dataloader = DataLoader(self.dataset, self.batch_sz, shuffle=False, num_workers=4)
        # Which element of the model output holds the logits when the model
        # returns a list/tuple.
        self.gen_output_index = opt_eval.get('gen_index', 0)
        self.masking = opt_get(opt_eval, ['masking'], False)
        if self.masking:
            # NOTE(review): the mask producer is pinned to 'cuda' and to fixed
            # checkpoint paths, while perform_eval moves inputs to
            # env['device']; these presumably need to match — confirm.
            self.mask_producer = UResnetMaskProducer(pretrained_uresnet_path='../experiments/train_imagenet_pixpro_resnet/models/66500_generator.pth',
                                                     kmeans_centroid_path='../experiments/k_means_uresnet_imagenet_256.pth',
                                                     mask_scales=[.03125, .0625, .125, .25, .5, 1.0],
                                                     tail_dim=256).to('cuda')

    def accuracy(self, output, target, topk=(1,)):
        """Compute top-k accuracy for each k in ``topk``.

        Args:
            output: logits, indexed as (batch, classes) by ``topk(maxk, 1)``.
            target: class indices, one per row of ``output``.
            topk: iterable of k values to report.

        Returns:
            List of 0-dim tensors, one per k: the percentage (0-100) of rows in
            this batch whose target is among the k highest-scoring classes.
        """
        with torch.no_grad():
            maxk = max(topk)
            batch_size = target.size(0)

            _, pred = output.topk(maxk, 1, True, True)
            # (maxk, batch): row i holds each sample's i-th best class.
            pred = pred.t()
            # target[None] is (1, batch) and broadcasts against every rank row.
            correct = pred.eq(target[None])

            res = []
            for k in topk:
                correct_k = correct[:k].flatten().sum(dtype=torch.float32)
                res.append(correct_k * (100.0 / batch_size))
            return res

    def perform_eval(self):
        """Run the model over the full evaluation dataloader.

        The model is put in eval mode for the pass and switched back to train
        mode afterwards, even if an exception is raised.

        Returns:
            dict with ``val_cross_entropy`` (mean per-sample cross-entropy) and
            ``top_5_accuracy`` / ``top_1_accuracy`` (percentages, 0-100). All
            values are averaged over samples, not batches.
        """
        total_samples = 0
        ce_loss = 0.0
        top_5_acc = 0.0
        top_1_acc = 0.0

        self.model.eval()
        try:
            with torch.no_grad():
                for batch in tqdm(self.dataloader):
                    hq, labels = batch['hq'], batch['labels']
                    hq = hq.to(self.env['device'])
                    labels = labels.to(self.env['device'])
                    if self.masking:
                        masks = self.mask_producer(hq)
                        logits = self.model(hq, masks)
                    else:
                        logits = self.model(hq)
                    if not isinstance(logits, (list, tuple)):
                        logits = [logits]
                    logits = logits[self.gen_output_index]

                    n = labels.size(0)
                    # Sum (not mean) so every sample counts equally, including
                    # those in a short final batch.
                    ce_loss += torch.nn.functional.cross_entropy(logits, labels, reduction='sum')
                    t1, t5 = self.accuracy(logits, labels, (1, 5))
                    # accuracy() returns per-batch percentages; weight them by
                    # batch size before averaging over all samples.
                    top_1_acc += t1 * n
                    top_5_acc += t5 * n
                    total_samples += n
        finally:
            self.model.train()

        return {"val_cross_entropy": ce_loss / total_samples,
                "top_5_accuracy": top_5_acc / total_samples,
                "top_1_accuracy": top_1_acc / total_samples}


if __name__ == '__main__':
    # Manual smoke test: evaluate a pretrained torchvision ResNet-50 on GPU.
    from torchvision.models import resnet50
    model = resnet50(pretrained=True).to('cuda')
    opt = {
        'batch_size': 128,
        'gen_index': 0,
        'masking': False,
        # TODO(review): CategorizationLossEvaluator.__init__ reads
        # opt_eval['dataset'], which is not set here, so this script raises
        # KeyError as written. Add the dataset options before running.
    }
    env = {
        'device': 'cuda',
    }
    # Named to avoid shadowing the ``eval`` builtin.
    cat_evaluator = CategorizationLossEvaluator(model, opt, env)
    print(cat_evaluator.perform_eval())