'''
REDS dataset
support reading images from lmdb, image folder and memcached
'''
import os.path as osp
import random
import pickle
import logging
import numpy as np
import cv2
import lmdb
import torch
import torch.utils.data as data
import data.util as util
try:
    import mc  # import memcached
except ImportError:
    pass

logger = logging.getLogger('base')


class REDSDataset(data.Dataset):
    '''
    Reading the training REDS dataset
    key example: 000_00000000
    GT: Ground-Truth; LQ: Low-Quality, e.g., low-resolution/blurry/noisy/compressed frames
    support reading N LQ frames, N = 1, 3, 5, 7
    '''

    def __init__(self, opt):
        super(REDSDataset, self).__init__()
        self.opt = opt
        # temporal augmentation
        self.interval_list = opt['interval_list']
        self.random_reverse = opt['random_reverse']
        logger.info('Temporal augmentation interval list: [{}]; random reverse: {}.'.format(
            ','.join(str(x) for x in opt['interval_list']), self.random_reverse))

        self.half_N_frames = opt['N_frames'] // 2
        self.GT_root, self.LQ_root = opt['dataroot_GT'], opt['dataroot_LQ']
        self.data_type = self.opt['data_type']
        self.LR_input = False if opt['target_size'] == opt['LQ_size'] else True  # low resolution inputs
        #### directly load image keys
        if self.data_type == 'lmdb':
            self.paths_GT, _ = util.get_image_paths(self.data_type, opt['dataroot_GT'])
            logger.info('Using lmdb meta info for cache keys.')
        elif opt['cache_keys']:
            logger.info('Using cache keys: {}'.format(opt['cache_keys']))
            self.paths_GT = pickle.load(open(opt['cache_keys'], 'rb'))['keys']
        else:
            raise ValueError(
                'Need to create cache keys (meta_info.pkl) by running [create_lmdb.py]')
        # remove the four REDS4 clips reserved for testing
        self.paths_GT = [
            v for v in self.paths_GT if v.split('_')[0] not in ['000', '011', '015', '020']
        ]
        assert self.paths_GT, 'Error: GT path is empty.'

        if self.data_type == 'lmdb':
            self.GT_env, self.LQ_env = None, None
        elif self.data_type == 'mc':  # memcached
            self.mclient = None
        elif self.data_type == 'img':
            pass
        else:
            raise ValueError('Wrong data type: {}'.format(self.data_type))
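
    # Option keys read by this class, for reference (grouped by purpose; see the
    # usage sketch at the bottom of the file for concrete example values):
    #   I/O:       dataroot_GT, dataroot_LQ, data_type ('lmdb' | 'img' | 'mc'),
    #              cache_keys (meta_info.pkl; only used for non-lmdb data)
    #   sampling:  N_frames, interval_list, random_reverse, border_mode
    #   cropping:  scale, target_size (GT crop), LQ_size (LQ crop)
    #   training:  phase, use_flip, use_rot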

    def _init_lmdb(self):
        # Open the lmdb environments lazily (here, not in __init__) so that each
        # DataLoader worker creates its own handle after fork; see:
        # https://github.com/chainer/chainermn/issues/129
        self.GT_env = lmdb.open(self.opt['dataroot_GT'], readonly=True, lock=False,
                                readahead=False, meminit=False)
        self.LQ_env = lmdb.open(self.opt['dataroot_LQ'], readonly=True, lock=False,
                                readahead=False, meminit=False)

    def _ensure_memcached(self):
        if self.mclient is None:
            # specify the config files
            server_list_config_file = None
            client_config_file = None
            self.mclient = mc.MemcachedClient.GetInstance(server_list_config_file,
                                                          client_config_file)

    def _read_img_mc(self, path):
        '''Return BGR, HWC, [0, 255], uint8.'''
        value = mc.pyvector()
        self.mclient.Get(path, value)
        value_buf = mc.ConvertBuffer(value)
        img_array = np.frombuffer(value_buf, np.uint8)
        img = cv2.imdecode(img_array, cv2.IMREAD_UNCHANGED)
        return img

    def _read_img_mc_BGR(self, path, name_a, name_b):
        '''Read the B, G, R channels separately and then combine them, to stay
        under the 1 MB per-object limit on the cluster.'''
        img_B = self._read_img_mc(osp.join(path + '_B', name_a, name_b + '.png'))
        img_G = self._read_img_mc(osp.join(path + '_G', name_a, name_b + '.png'))
        img_R = self._read_img_mc(osp.join(path + '_R', name_a, name_b + '.png'))
        img = cv2.merge((img_B, img_G, img_R))
        return img
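
    # A standalone sketch of the lmdb read path used by __getitem__ below. Keys
    # follow the '<clip>_<frame>' pattern from the class docstring (e.g.
    # '001_00000010'), and the shape tuple tells data.util.read_img how to
    # decode the raw buffer. The dataset path and key here are illustrative
    # assumptions, and read_img is assumed to return an HWC, BGR, [0, 1]
    # float32 array, as elsewhere in this repo:
    #
    #   env = lmdb.open('datasets/REDS/train_sharp_wval.lmdb', readonly=True,
    #                   lock=False, readahead=False, meminit=False)
    #   img_GT = util.read_img(env, '001_00000010', (3, 720, 1280))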

    def __getitem__(self, index):
        if self.data_type == 'mc':
            self._ensure_memcached()
        elif self.data_type == 'lmdb' and (self.GT_env is None or self.LQ_env is None):
            self._init_lmdb()

        scale = self.opt['scale']
        GT_size = self.opt['target_size']
        key = self.paths_GT[index]
        name_a, name_b = key.split('_')
        center_frame_idx = int(name_b)

        #### determine the neighbor frames
        # each REDS clip has 100 frames, indexed 00000000-00000099
        interval = random.choice(self.interval_list)
        if self.opt['border_mode']:
            direction = 1  # 1: forward; 0: backward
            N_frames = self.opt['N_frames']
            if self.random_reverse and random.random() < 0.5:
                direction = random.choice([0, 1])
            if center_frame_idx + interval * (N_frames - 1) > 99:
                direction = 0
            elif center_frame_idx - interval * (N_frames - 1) < 0:
                direction = 1
            # get the neighbor list
            if direction == 1:
                neighbor_list = list(
                    range(center_frame_idx, center_frame_idx + interval * N_frames, interval))
            else:
                neighbor_list = list(
                    range(center_frame_idx, center_frame_idx - interval * N_frames, -interval))
            name_b = '{:08d}'.format(neighbor_list[0])
        else:
            # ensure not exceeding the borders
            while (center_frame_idx + self.half_N_frames * interval > 99) or (
                    center_frame_idx - self.half_N_frames * interval < 0):
                center_frame_idx = random.randint(0, 99)
            # get the neighbor list
            neighbor_list = list(
                range(center_frame_idx - self.half_N_frames * interval,
                      center_frame_idx + self.half_N_frames * interval + 1, interval))
            if self.random_reverse and random.random() < 0.5:
                neighbor_list.reverse()
            name_b = '{:08d}'.format(neighbor_list[self.half_N_frames])

        assert len(neighbor_list) == self.opt['N_frames'], \
            'Wrong length of neighbor list: {}'.format(len(neighbor_list))

        #### get the GT image (the center frame, or the first frame in border mode)
        if self.data_type == 'mc':
            img_GT = self._read_img_mc_BGR(self.GT_root, name_a, name_b)
            img_GT = img_GT.astype(np.float32) / 255.
        elif self.data_type == 'lmdb':
            img_GT = util.read_img(self.GT_env, key, (3, 720, 1280))
        else:
            img_GT = util.read_img(None, osp.join(self.GT_root, name_a, name_b + '.png'))

        #### get LQ images
        LQ_size_tuple = (3, 180, 320) if self.LR_input else (3, 720, 1280)
        img_LQ_l = []
        for v in neighbor_list:
            img_LQ_path = osp.join(self.LQ_root, name_a, '{:08d}.png'.format(v))
            if self.data_type == 'mc':
                if self.LR_input:
                    img_LQ = self._read_img_mc(img_LQ_path)
                else:
                    img_LQ = self._read_img_mc_BGR(self.LQ_root, name_a, '{:08d}'.format(v))
                img_LQ = img_LQ.astype(np.float32) / 255.
            elif self.data_type == 'lmdb':
                img_LQ = util.read_img(self.LQ_env, '{}_{:08d}'.format(name_a, v), LQ_size_tuple)
            else:
                img_LQ = util.read_img(None, img_LQ_path)
            img_LQ_l.append(img_LQ)

        if self.opt['phase'] == 'train':
            C, H, W = LQ_size_tuple  # LQ size
            # randomly crop
            if self.LR_input:
                LQ_size = GT_size // scale
                rnd_h = random.randint(0, max(0, H - LQ_size))
                rnd_w = random.randint(0, max(0, W - LQ_size))
                img_LQ_l = [v[rnd_h:rnd_h + LQ_size, rnd_w:rnd_w + LQ_size, :] for v in img_LQ_l]
                rnd_h_HR, rnd_w_HR = int(rnd_h * scale), int(rnd_w * scale)
                img_GT = img_GT[rnd_h_HR:rnd_h_HR + GT_size, rnd_w_HR:rnd_w_HR + GT_size, :]
            else:
                rnd_h = random.randint(0, max(0, H - GT_size))
                rnd_w = random.randint(0, max(0, W - GT_size))
                img_LQ_l = [v[rnd_h:rnd_h + GT_size, rnd_w:rnd_w + GT_size, :] for v in img_LQ_l]
                img_GT = img_GT[rnd_h:rnd_h + GT_size, rnd_w:rnd_w + GT_size, :]

            # augmentation - flip, rotate
            img_LQ_l.append(img_GT)
            rlt = util.augment(img_LQ_l, self.opt['use_flip'], self.opt['use_rot'])
            img_LQ_l = rlt[0:-1]
            img_GT = rlt[-1]

        # stack LQ images to NHWC, N is the frame number
        img_LQs = np.stack(img_LQ_l, axis=0)
        # BGR to RGB, HWC to CHW, numpy to tensor
        img_GT = img_GT[:, :, [2, 1, 0]]
        img_LQs = img_LQs[:, :, :, [2, 1, 0]]
        img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float()
        img_LQs = torch.from_numpy(
            np.ascontiguousarray(np.transpose(img_LQs, (0, 3, 1, 2)))).float()
        return {'LQs': img_LQs, 'GT': img_GT, 'key': key}

    def __len__(self):
        return len(self.paths_GT)
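
if __name__ == '__main__':
    # Minimal usage sketch, assuming an lmdb-backed REDS training set. The
    # dataset paths below are illustrative assumptions; the remaining values
    # mirror the option keys read by REDSDataset above.
    opt = {
        'dataroot_GT': 'datasets/REDS/train_sharp_wval.lmdb',  # assumed path
        'dataroot_LQ': 'datasets/REDS/train_sharp_bicubic_wval.lmdb',  # assumed path
        'data_type': 'lmdb',
        'cache_keys': None,  # unused for lmdb
        'N_frames': 5,
        'interval_list': [1],
        'random_reverse': False,
        'border_mode': False,
        'scale': 4,
        'target_size': 256,
        'LQ_size': 64,
        'phase': 'train',
        'use_flip': True,
        'use_rot': True,
    }
    dataset = REDSDataset(opt)
    loader = data.DataLoader(dataset, batch_size=16, shuffle=True, num_workers=4,
                             pin_memory=True)
    batch = next(iter(loader))
    # LQs: (B, N_frames, C, LQ_size, LQ_size); GT: (B, C, target_size, target_size)
    print(batch['LQs'].shape, batch['GT'].shape, batch['key'][:2])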