import os.path as osp
import torch
import torch.utils.data as data
import data.util as util


class VideoTestDataset(data.Dataset):
    """A video test dataset.

    Support (selected by ``opt['name']``, case-insensitive):
        Vid4, REDS4 -- one sub-folder per clip under ``dataroot_LQ`` and
            ``dataroot_GT``; LQ and GT sub-folders are paired in the order
            returned by ``util.glob_file_list``.
        Vimeo90K-Test -- not implemented yet; the dataset is left empty.

    No need to prepare LMDB files.

    Each item is one GT frame plus the LQ frames chosen for it by
    ``util.index_generation`` (``opt['N_frames']`` frames, ``opt['padding']``
    mode).
    """

    def __init__(self, opt):
        """Scan the LQ/GT folders and optionally cache every clip in memory.

        Args:
            opt (dict): dataset options. Keys read here: 'cache_data',
                'N_frames', 'dataroot_GT', 'dataroot_LQ', 'data_type', 'name'.
                ``__getitem__`` additionally reads 'padding'.

        Raises:
            ValueError: if ``data_type`` is 'lmdb', the dataset name is not
                supported, or the LQ and GT roots contain different numbers
                of sub-folders.
            AssertionError: if a paired LQ/GT sub-folder holds a different
                number of images.
        """
        super(VideoTestDataset, self).__init__()
        self.opt = opt
        self.cache_data = opt['cache_data']
        self.half_N_frames = opt['N_frames'] // 2
        self.GT_root, self.LQ_root = opt['dataroot_GT'], opt['dataroot_LQ']
        self.data_type = self.opt['data_type']
        # Parallel lists: position k in every list describes the k-th GT frame
        # of the whole dataset. Frames of one clip are appended contiguously,
        # in order, so a clip occupies a consecutive run of positions.
        self.data_info = {'path_LQ': [], 'path_GT': [], 'folder': [], 'idx': [], 'border': []}
        if self.data_type == 'lmdb':
            raise ValueError('No need to use LMDB during validation/test.')
        #### Generate data info and cache data
        # Per-clip image tensors keyed by sub-folder name (filled only when
        # cache_data is truthy).
        self.imgs_LQ, self.imgs_GT = {}, {}
        if opt['name'].lower() in ['vid4', 'reds4']:
            subfolders_LQ = util.glob_file_list(self.LQ_root)
            subfolders_GT = util.glob_file_list(self.GT_root)
            # zip() below would silently drop unmatched clips; fail loudly instead.
            if len(subfolders_LQ) != len(subfolders_GT):
                raise ValueError(
                    'Different number of sub-folders in LQ ({}) and GT ({}) roots'.format(
                        len(subfolders_LQ), len(subfolders_GT)))
            half_n = self.half_N_frames
            for subfolder_LQ, subfolder_GT in zip(subfolders_LQ, subfolders_GT):
                subfolder_name = osp.basename(subfolder_GT)
                img_paths_LQ = util.glob_file_list(subfolder_LQ)
                img_paths_GT = util.glob_file_list(subfolder_GT)
                max_idx = len(img_paths_LQ)
                assert max_idx == len(
                    img_paths_GT), 'Different number of images in LQ and GT folders'
                self.data_info['path_LQ'].extend(img_paths_LQ)
                self.data_info['path_GT'].extend(img_paths_GT)
                self.data_info['folder'].extend([subfolder_name] * max_idx)
                # '<frame index>/<frames in clip>'; parsed back in __getitem__.
                self.data_info['idx'].extend(
                    '{}/{}'.format(i, max_idx) for i in range(max_idx))
                # 1 for the first and last half_N_frames frames of the clip,
                # else 0. Evaluated per index so clips shorter than
                # half_N_frames do not raise IndexError (they become all 1s).
                self.data_info['border'].extend(
                    [1 if i < half_n or i >= max_idx - half_n else 0
                     for i in range(max_idx)])
                if self.cache_data:
                    self.imgs_LQ[subfolder_name] = util.read_img_seq(img_paths_LQ)
                    self.imgs_GT[subfolder_name] = util.read_img_seq(img_paths_GT)
        elif opt['name'].lower() in ['vimeo90k-test']:
            pass  # TODO: not implemented; data_info stays empty (len == 0).
        else:
            raise ValueError(
                'Not support video test dataset. Support Vid4, REDS4 and Vimeo90k-Test.')

    def __getitem__(self, index):
        """Return the sample for the ``index``-th GT frame of the dataset.

        Returns:
            dict with keys:
                'LQs': LQ frames at the indices from ``util.index_generation``,
                    stacked along dim 0.
                'GT': the GT frame at the same position in its clip.
                'folder': clip (sub-folder) name.
                'idx': ``'<frame index>/<frames in clip>'`` string.
                'border': 1 if the frame is within half_N_frames of either
                    clip end, else 0.
        """
        folder = self.data_info['folder'][index]
        idx, max_idx = self.data_info['idx'][index].split('/')
        idx, max_idx = int(idx), int(max_idx)
        border = self.data_info['border'][index]
        select_idx = util.index_generation(idx, max_idx, self.opt['N_frames'],
                                           padding=self.opt['padding'])
        if self.cache_data:
            imgs_LQ = self.imgs_LQ[folder].index_select(0, torch.LongTensor(select_idx))
            img_GT = self.imgs_GT[folder][idx]
        else:
            # Previously a bare `pass`, which left imgs_LQ/img_GT unbound.
            # Read only the needed frames from disk instead of the cache.
            if index < 0:  # mirror the negative-index support of the list lookups
                index += len(self)
            # A clip's frames are contiguous in data_info, so its first entry
            # sits `idx` positions before `index`.
            start = index - idx
            clip_paths_LQ = self.data_info['path_LQ'][start:start + max_idx]
            # NOTE(review): assumes util.read_img_seq reads each listed path
            # independently, so this matches the cached index_select result.
            imgs_LQ = util.read_img_seq([clip_paths_LQ[i] for i in select_idx])
            img_GT = util.read_img_seq([self.data_info['path_GT'][index]])[0]
        return {
            'LQs': imgs_LQ,
            'GT': img_GT,
            'folder': folder,
            'idx': self.data_info['idx'][index],
            'border': border
        }

    def __len__(self):
        """Return the total number of GT frames across all clips."""
        return len(self.data_info['path_GT'])