import numpy as np
import lmdb
import torch
import torch.utils.data as data
import data.util as util
import torchvision.transforms.functional as F
from PIL import Image
import os.path as osp


class LQDataset(data.Dataset):
    '''Read LQ images only in the test phase.

    Each LQ image can optionally be cut into ``vertical_splits`` equal-width
    vertical strips; every strip is then exposed as its own dataset item, and
    its returned path gets a ``_<strip index>`` suffix on the file stem.

    Options read from ``opt``:
        data_type (str): storage type forwarded to ``util.get_image_paths``;
            ``'lmdb'`` additionally opens an lmdb environment lazily.
        dataroot_LQ (str): location of the LQ images.
        start_at (int, optional): number of leading LQ paths to skip.
            Defaults to 0.
        vertical_splits (int, optional): number of vertical strips per image.
            Values <= 0, or a missing/None entry, disable splitting.
    '''

    def __init__(self, opt):
        super().__init__()
        self.opt = opt
        self.data_type = self.opt['data_type']
        self.start_at = self.opt.get('start_at', 0)
        # A missing or None entry means "no splitting", same as an explicit 0.
        self.vertical_splits = int(self.opt.get('vertical_splits') or 0)
        self.paths_LQ, self.paths_GT = None, None
        self.LQ_env = None  # environment for lmdb, opened lazily in __getitem__

        paths_LQ, self.sizes_LQ = util.get_image_paths(self.data_type, opt['dataroot_LQ'])
        # `or []` so that an empty/None path list (in case the helper returns
        # None) reaches the explicit assert below instead of failing on slicing.
        self.paths_LQ = (paths_LQ or [])[self.start_at:]
        assert self.paths_LQ, 'Error: LQ paths are empty.'

    def _init_lmdb(self):
        '''Open the LQ lmdb environment read-only (called lazily per worker).'''
        self.LQ_env = lmdb.open(self.opt['dataroot_LQ'], readonly=True, lock=False, readahead=False,
                                meminit=False)

    def _locate(self, index):
        '''Map a dataset index to ``(path_index, split_index)``.

        ``split_index`` is None when vertical splitting is disabled.
        '''
        if self.vertical_splits > 0:
            return divmod(index, self.vertical_splits)
        return index, None

    @staticmethod
    def _split_path(path, split_index):
        '''Return ``path`` with ``_<split_index>`` appended to its file stem.

        Only the final stem is changed; directory components are left intact.
        ``split_index`` of None returns ``path`` unchanged.
        '''
        if split_index is None:
            return path
        root, ext = osp.splitext(path)
        return '%s_%i%s' % (root, split_index, ext)

    def __getitem__(self, index):
        '''Load one LQ image (or one vertical strip of it) as a tensor.

        Returns:
            dict with ``'LQ'`` (tensor from ``F.to_tensor``) and ``'LQ_path'``
            (source path, suffixed with the strip index when splitting).
        '''
        if self.data_type == 'lmdb' and self.LQ_env is None:
            self._init_lmdb()
        path_index, split_index = self._locate(index)

        # get LQ image
        LQ_path = self.paths_LQ[path_index]
        # NOTE(review): for data_type == 'lmdb' the opened LQ_env is never read
        # from here; LQ_path is opened as a regular file. Confirm lmdb support.
        # `with` closes the file handle PIL keeps open after lazy Image.open.
        with Image.open(LQ_path) as img:
            region = img
            if split_index is not None:
                w, h = img.size
                # Integer strip width: when w is not divisible by the number of
                # splits, the rightmost w % vertical_splits columns are dropped.
                w_per_split = w // self.vertical_splits
                region = F.crop(img, 0, w_per_split * split_index, h, w_per_split)
            img_LQ = F.to_tensor(region)

        return {'LQ': img_LQ, 'LQ_path': self._split_path(LQ_path, split_index)}

    def __len__(self):
        '''Number of items: one per image, or one per strip when splitting.'''
        if self.vertical_splits > 0:
            return len(self.paths_LQ) * self.vertical_splits
        return len(self.paths_LQ)