forked from mrq/DL-Art-School
168 lines
6.7 KiB
Python
168 lines
6.7 KiB
Python
|
'''
|
||
|
Vimeo90K dataset
|
||
|
support reading images from lmdb, image folder and memcached
|
||
|
'''
|
||
|
import os.path as osp
|
||
|
import random
|
||
|
import pickle
|
||
|
import logging
|
||
|
import numpy as np
|
||
|
import cv2
|
||
|
import lmdb
|
||
|
import torch
|
||
|
import torch.utils.data as data
|
||
|
import data.util as util
|
||
|
try:
|
||
|
import mc # import memcached
|
||
|
except ImportError:
|
||
|
pass
|
||
|
logger = logging.getLogger('base')
|
||
|
|
||
|
|
||
|
class Vimeo90KDataset(data.Dataset):
|
||
|
'''
|
||
|
Reading the training Vimeo90K dataset
|
||
|
key example: 00001_0001 (_1, ..., _7)
|
||
|
GT (Ground-Truth): 4th frame;
|
||
|
LQ (Low-Quality): support reading N LQ frames, N = 1, 3, 5, 7 centered with 4th frame
|
||
|
'''
|
||
|
|
||
|
def __init__(self, opt):
|
||
|
super(Vimeo90KDataset, self).__init__()
|
||
|
self.opt = opt
|
||
|
# temporal augmentation
|
||
|
self.interval_list = opt['interval_list']
|
||
|
self.random_reverse = opt['random_reverse']
|
||
|
logger.info('Temporal augmentation interval list: [{}], with random reverse is {}.'.format(
|
||
|
','.join(str(x) for x in opt['interval_list']), self.random_reverse))
|
||
|
|
||
|
self.GT_root, self.LQ_root = opt['dataroot_GT'], opt['dataroot_LQ']
|
||
|
self.data_type = self.opt['data_type']
|
||
|
self.LR_input = False if opt['GT_size'] == opt['LQ_size'] else True # low resolution inputs
|
||
|
|
||
|
#### determine the LQ frame list
|
||
|
'''
|
||
|
N | frames
|
||
|
1 | 4
|
||
|
3 | 3,4,5
|
||
|
5 | 2,3,4,5,6
|
||
|
7 | 1,2,3,4,5,6,7
|
||
|
'''
|
||
|
self.LQ_frames_list = []
|
||
|
for i in range(opt['N_frames']):
|
||
|
self.LQ_frames_list.append(i + (9 - opt['N_frames']) // 2)
|
||
|
|
||
|
#### directly load image keys
|
||
|
if self.data_type == 'lmdb':
|
||
|
self.paths_GT, _ = util.get_image_paths(self.data_type, opt['dataroot_GT'])
|
||
|
logger.info('Using lmdb meta info for cache keys.')
|
||
|
elif opt['cache_keys']:
|
||
|
logger.info('Using cache keys: {}'.format(opt['cache_keys']))
|
||
|
self.paths_GT = pickle.load(open(opt['cache_keys'], 'rb'))['keys']
|
||
|
else:
|
||
|
raise ValueError(
|
||
|
'Need to create cache keys (meta_info.pkl) by running [create_lmdb.py]')
|
||
|
assert self.paths_GT, 'Error: GT path is empty.'
|
||
|
|
||
|
if self.data_type == 'lmdb':
|
||
|
self.GT_env, self.LQ_env = None, None
|
||
|
elif self.data_type == 'mc': # memcached
|
||
|
self.mclient = None
|
||
|
elif self.data_type == 'img':
|
||
|
pass
|
||
|
else:
|
||
|
raise ValueError('Wrong data type: {}'.format(self.data_type))
|
||
|
|
||
|
def _init_lmdb(self):
|
||
|
# https://github.com/chainer/chainermn/issues/129
|
||
|
self.GT_env = lmdb.open(self.opt['dataroot_GT'], readonly=True, lock=False, readahead=False,
|
||
|
meminit=False)
|
||
|
self.LQ_env = lmdb.open(self.opt['dataroot_LQ'], readonly=True, lock=False, readahead=False,
|
||
|
meminit=False)
|
||
|
|
||
|
def _ensure_memcached(self):
|
||
|
if self.mclient is None:
|
||
|
# specify the config files
|
||
|
server_list_config_file = None
|
||
|
client_config_file = None
|
||
|
self.mclient = mc.MemcachedClient.GetInstance(server_list_config_file,
|
||
|
client_config_file)
|
||
|
|
||
|
def _read_img_mc(self, path):
|
||
|
''' Return BGR, HWC, [0, 255], uint8'''
|
||
|
value = mc.pyvector()
|
||
|
self.mclient.Get(path, value)
|
||
|
value_buf = mc.ConvertBuffer(value)
|
||
|
img_array = np.frombuffer(value_buf, np.uint8)
|
||
|
img = cv2.imdecode(img_array, cv2.IMREAD_UNCHANGED)
|
||
|
return img
|
||
|
|
||
|
def __getitem__(self, index):
|
||
|
if self.data_type == 'mc':
|
||
|
self._ensure_memcached()
|
||
|
elif self.data_type == 'lmdb' and (self.GT_env is None or self.LQ_env is None):
|
||
|
self._init_lmdb()
|
||
|
|
||
|
scale = self.opt['scale']
|
||
|
GT_size = self.opt['GT_size']
|
||
|
key = self.paths_GT[index]
|
||
|
name_a, name_b = key.split('_')
|
||
|
#### get the GT image (as the center frame)
|
||
|
if self.data_type == 'mc':
|
||
|
img_GT = self._read_img_mc(osp.join(self.GT_root, name_a, name_b, '4.png'))
|
||
|
img_GT = img_GT.astype(np.float32) / 255.
|
||
|
elif self.data_type == 'lmdb':
|
||
|
img_GT = util.read_img(self.GT_env, key + '_4', (3, 256, 448))
|
||
|
else:
|
||
|
img_GT = util.read_img(None, osp.join(self.GT_root, name_a, name_b, 'im4.png'))
|
||
|
|
||
|
#### get LQ images
|
||
|
LQ_size_tuple = (3, 64, 112) if self.LR_input else (3, 256, 448)
|
||
|
img_LQ_l = []
|
||
|
for v in self.LQ_frames_list:
|
||
|
if self.data_type == 'mc':
|
||
|
img_LQ = self._read_img_mc(
|
||
|
osp.join(self.LQ_root, name_a, name_b, '{}.png'.format(v)))
|
||
|
img_LQ = img_LQ.astype(np.float32) / 255.
|
||
|
elif self.data_type == 'lmdb':
|
||
|
img_LQ = util.read_img(self.LQ_env, key + '_{}'.format(v), LQ_size_tuple)
|
||
|
else:
|
||
|
img_LQ = util.read_img(None,
|
||
|
osp.join(self.LQ_root, name_a, name_b, 'im{}.png'.format(v)))
|
||
|
img_LQ_l.append(img_LQ)
|
||
|
|
||
|
if self.opt['phase'] == 'train':
|
||
|
C, H, W = LQ_size_tuple # LQ size
|
||
|
# randomly crop
|
||
|
if self.LR_input:
|
||
|
LQ_size = GT_size // scale
|
||
|
rnd_h = random.randint(0, max(0, H - LQ_size))
|
||
|
rnd_w = random.randint(0, max(0, W - LQ_size))
|
||
|
img_LQ_l = [v[rnd_h:rnd_h + LQ_size, rnd_w:rnd_w + LQ_size, :] for v in img_LQ_l]
|
||
|
rnd_h_HR, rnd_w_HR = int(rnd_h * scale), int(rnd_w * scale)
|
||
|
img_GT = img_GT[rnd_h_HR:rnd_h_HR + GT_size, rnd_w_HR:rnd_w_HR + GT_size, :]
|
||
|
else:
|
||
|
rnd_h = random.randint(0, max(0, H - GT_size))
|
||
|
rnd_w = random.randint(0, max(0, W - GT_size))
|
||
|
img_LQ_l = [v[rnd_h:rnd_h + GT_size, rnd_w:rnd_w + GT_size, :] for v in img_LQ_l]
|
||
|
img_GT = img_GT[rnd_h:rnd_h + GT_size, rnd_w:rnd_w + GT_size, :]
|
||
|
|
||
|
# augmentation - flip, rotate
|
||
|
img_LQ_l.append(img_GT)
|
||
|
rlt = util.augment(img_LQ_l, self.opt['use_flip'], self.opt['use_rot'])
|
||
|
img_LQ_l = rlt[0:-1]
|
||
|
img_GT = rlt[-1]
|
||
|
|
||
|
# stack LQ images to NHWC, N is the frame number
|
||
|
img_LQs = np.stack(img_LQ_l, axis=0)
|
||
|
# BGR to RGB, HWC to CHW, numpy to tensor
|
||
|
img_GT = img_GT[:, :, [2, 1, 0]]
|
||
|
img_LQs = img_LQs[:, :, :, [2, 1, 0]]
|
||
|
img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float()
|
||
|
img_LQs = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LQs,
|
||
|
(0, 3, 1, 2)))).float()
|
||
|
return {'LQs': img_LQs, 'GT': img_GT, 'key': key}
|
||
|
|
||
|
def __len__(self):
|
||
|
return len(self.paths_GT)
|