import random
import numpy as np
import cv2
import lmdb
import torch
import torch.utils.data as data
import data.util as util
from PIL import Image
from io import BytesIO
import torchvision.transforms.functional as F


class DownsampleDataset(data.Dataset):
    """
    Reads an unpaired HQ and LQ image. Clips both images to the expected input sizes of the model. Produces a
    downsampled LQ image from the HQ image and feeds that as well.
    """

    def __init__(self, opt):
        super(DownsampleDataset, self).__init__()
        self.opt = opt
        self.data_type = self.opt['data_type']
        self.paths_LQ, self.paths_GT = None, None
        self.sizes_LQ, self.sizes_GT = None, None
        self.LQ_env, self.GT_env = None, None  # environments for lmdb
        self.doCrop = self.opt['doCrop']

        self.paths_GT, self.sizes_GT = util.get_image_paths(self.data_type, opt['dataroot_GT'])
        self.paths_LQ, self.sizes_LQ = util.get_image_paths(self.data_type, opt['dataroot_LQ'])

        self.data_sz_mismatch_ok = opt['mismatched_Data_OK']
        assert self.paths_GT, 'Error: GT path is empty.'
        assert self.paths_LQ, 'LQ is required for downsampling.'
        if not self.data_sz_mismatch_ok:
            assert len(self.paths_LQ) == len(
                self.paths_GT
            ), 'GT and LQ datasets have different number of images - {}, {}.'.format(
                len(self.paths_LQ), len(self.paths_GT))
        self.random_scale_list = [1]

    def _init_lmdb(self):
        # https://github.com/chainer/chainermn/issues/129
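        # The environments are opened lazily (here, not in __init__) so that each DataLoader worker
        # process creates its own handle; lmdb environments cannot safely be shared across fork().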
        self.GT_env = lmdb.open(self.opt['dataroot_GT'], readonly=True, lock=False, readahead=False,
                                meminit=False)
        self.LQ_env = lmdb.open(self.opt['dataroot_LQ'], readonly=True, lock=False, readahead=False,
                                meminit=False)

    def __getitem__(self, index):
        if self.data_type == 'lmdb' and (self.GT_env is None or self.LQ_env is None):
            self._init_lmdb()
        scale = self.opt['scale']
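        # 'target_size' is the LQ patch size the model consumes; the GT crop must be 'scale' times larger
        # (e.g. target_size=64 with scale=4 yields 256px GT patches and 64px LQ patches).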
        GT_size = self.opt['target_size'] * scale

        # get GT image
        GT_path = self.paths_GT[index % len(self.paths_GT)]
        if self.data_type == 'lmdb':
            resolution = [int(s) for s in self.sizes_GT[index % len(self.paths_GT)].split('_')]
        else:
            resolution = None
        img_GT = util.read_img(self.GT_env, GT_path, resolution)
        if self.opt['phase'] != 'train':  # modcrop in the validation / test phase
            img_GT = util.modcrop(img_GT, scale)
        if self.opt['color']:  # change color space if necessary
            img_GT = util.channel_convert(img_GT.shape[2], self.opt['color'], [img_GT])[0]

        # get LQ image
        LQ_path = self.paths_LQ[index % len(self.paths_LQ)]
        if self.data_type == 'lmdb':
            resolution = [int(s) for s in self.sizes_LQ[index % len(self.paths_LQ)].split('_')]
        else:
            resolution = None
        img_LQ = util.read_img(self.LQ_env, LQ_path, resolution)

        if self.opt['phase'] == 'train':
            H, W, _ = img_GT.shape
            assert H >= GT_size and W >= GT_size, \
                'GT image {} ({}x{}) is smaller than the {}px crop size.'.format(GT_path, W, H, GT_size)

            H, W, C = img_LQ.shape
            LQ_size = GT_size // scale

            if self.doCrop:
                # randomly crop
                rnd_h = random.randint(0, max(0, H - LQ_size))
                rnd_w = random.randint(0, max(0, W - LQ_size))
                img_LQ = img_LQ[rnd_h:rnd_h + LQ_size, rnd_w:rnd_w + LQ_size, :]
                rnd_h_GT, rnd_w_GT = int(rnd_h * scale), int(rnd_w * scale)
                img_GT = img_GT[rnd_h_GT:rnd_h_GT + GT_size, rnd_w_GT:rnd_w_GT + GT_size, :]
            else:
                img_LQ = cv2.resize(img_LQ, (LQ_size, LQ_size), interpolation=cv2.INTER_LINEAR)
                img_GT = cv2.resize(img_GT, (GT_size, GT_size), interpolation=cv2.INTER_LINEAR)

            # augmentation - flip, rotate
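            # util.augment applies the same random flip/rotation to every image in the list, keeping
            # the LQ and GT patches spatially aligned.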
            img_LQ, img_GT = util.augment([img_LQ, img_GT], self.opt['use_flip'],
                                          self.opt['use_rot'])

        # BGR to RGB, HWC to CHW, numpy to tensor
        if img_GT.shape[2] == 3:
            img_GT = cv2.cvtColor(img_GT, cv2.COLOR_BGR2RGB)
            img_LQ = cv2.cvtColor(img_LQ, cv2.COLOR_BGR2RGB)

        # HQ needs to be converted to a PIL image so the compression-artifact transformation can be applied.
        H, W, _ = img_GT.shape
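        # util.read_img returns a float32 image in [0, 1]; PIL wants uint8 in [0, 255].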
        img_GT = (img_GT * 255).astype(np.uint8)
        img_GT = Image.fromarray(img_GT)
        if self.opt['use_compression_artifacts']:
            qf = random.randrange(15, 100)
            corruption_buffer = BytesIO()
            img_GT.save(corruption_buffer, "JPEG", quality=qf, optimize=True)
            corruption_buffer.seek(0)
            img_GT = Image.open(corruption_buffer)
        # Generate a downsampled image from HQ for feature and PIX losses.
        img_Downsampled = F.resize(img_GT, (int(H / scale), int(W / scale)))

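        # img_GT and img_Downsampled are PIL images here, so to_tensor() handles the HWC->CHW and
        # [0,255]->[0,1] conversions; img_LQ is still a float32 HWC numpy array, so transpose it manually.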
        img_GT = F.to_tensor(img_GT)
        img_Downsampled = F.to_tensor(img_Downsampled)
        img_LQ = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LQ, (2, 0, 1)))).float()

        # This may seem really messed up, but let me explain:
        #  The goal is to re-use existing code as much as possible. SRGAN_model was coded to upsample, not downsample,
        #  but it can be retrofitted. To do so, we need to "trick" it. In this case the "input" is the HQ image and the
        #  "output" is the LQ image. SRGAN_model will be using a Generator and a Discriminator which already know this;
        #  we just need to trick its logic into following the same rules.
        #  Do this by setting LQ (which is the input into the models) = img_GT and GT (which is the expected output) = img_LQ.
        #  PIX is used as a reference for the pixel loss. Use the manually downsampled image for this.
        return {'LQ': img_GT, 'GT': img_LQ, 'PIX': img_Downsampled, 'LQ_path': LQ_path, 'GT_path': GT_path}

    def __len__(self):
        return max(len(self.paths_GT), len(self.paths_LQ))
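

# A minimal smoke-test sketch (an illustration, not part of the training pipeline): it builds the
# dataset from a hand-written 'opt' dict. The keys mirror the ones this class reads above, but the
# dataroot paths and sizes are hypothetical placeholders - point them at real data before running.
if __name__ == '__main__':
    opt = {
        'data_type': 'img',                    # 'img' for loose files, 'lmdb' for an lmdb database
        'dataroot_GT': '/path/to/hq_images',   # hypothetical path
        'dataroot_LQ': '/path/to/lq_images',   # hypothetical path
        'mismatched_Data_OK': True,
        'doCrop': True,
        'scale': 4,
        'target_size': 64,                     # LQ patch size; GT crops will be 256px
        'phase': 'train',
        'color': None,
        'use_flip': True,
        'use_rot': True,
        'use_compression_artifacts': True,
    }
    dataset = DownsampleDataset(opt)
    sample = dataset[0]
    # Note the swap documented in __getitem__: 'LQ' holds the HQ tensor and 'GT' the LQ tensor.
    for k in ('LQ', 'GT', 'PIX'):
        print(k, tuple(sample[k].shape))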