DL-Art-School/codes/scripts/test_dataloader.py
James Betker 11155aead4 Directly use dataset keys
This has been a long time coming. Cleans up messy "GT" nomenclature and simplifies ExtensibleTrainer.feed_data
2020-12-04 20:14:53 -07:00

105 lines
3.6 KiB
Python

import sys
import os.path as osp
import math
import torchvision.utils
sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__))))
from data import create_dataloader, create_dataset # noqa: E402
from utils import util # noqa: E402
def _video_options(name, dataroot_gt, dataroot_lq, mode, n_frames):
    """Return the options shared by the multi-frame datasets (REDS, Vimeo90K).

    Only the name, the two lmdb roots, the mode and the frame count differ
    between those datasets; every other value is identical.
    """
    return {
        'name': name,
        'dataroot_GT': dataroot_gt,
        'dataroot_LQ': dataroot_lq,
        'mode': mode,
        'N_frames': n_frames,
        'phase': 'train',
        'use_shuffle': True,
        'n_workers': 8,
        'batch_size': 16,
        'target_size': 256,
        'LQ_size': 64,
        'scale': 4,
        'use_flip': True,
        'use_rot': True,
        'interval_list': [1],
        'random_reverse': False,
        'border_mode': False,
        'cache_keys': None,
        'data_type': 'lmdb',  # img | lmdb | mc
    }


def _build_options(dataset):
    """Return the dataset/dataloader option dict for ``dataset``.

    Raises:
        ValueError: if ``dataset`` is not REDS, Vimeo90K or DIV2K800_sub.
    """
    opt = {'dist': False, 'gpu_ids': [0]}
    if dataset == 'REDS':
        opt.update(_video_options(
            'test_REDS',
            '../../datasets/REDS/train_sharp_wval.lmdb',
            '../../datasets/REDS/train_sharp_bicubic_wval.lmdb',
            'REDS',
            5,
        ))
    elif dataset == 'Vimeo90K':
        opt.update(_video_options(
            'test_Vimeo90K',
            '../../datasets/vimeo90k/vimeo90k_train_GT.lmdb',
            '../../datasets/vimeo90k/vimeo90k_train_LR7frames.lmdb',
            'Vimeo90K',
            7,
        ))
    elif dataset == 'DIV2K800_sub':
        opt.update({
            'name': 'DIV2K800',
            'dataroot_GT': '../../datasets/DIV2K/DIV2K800_sub.lmdb',
            'dataroot_LQ': '../../datasets/DIV2K/DIV2K800_sub_bicLRx4.lmdb',
            'mode': 'LQGT',
            'phase': 'train',
            'use_shuffle': True,
            'n_workers': 8,
            'batch_size': 16,
            'target_size': 128,
            'scale': 4,
            'use_flip': True,
            'use_rot': True,
            'color': 'RGB',
            'data_type': 'lmdb',  # img | lmdb
        })
    else:
        raise ValueError(
            'Unsupported dataset {!r}. Please implement by yourself.'.format(dataset))
    return opt


def main(dataset='DIV2K800_sub'):
    """Dump the first few batches of a dataloader to ``tmp/`` as image grids.

    Builds the option dict for ``dataset``, creates the dataset and its
    dataloader, then writes the LQ and GT tensors of the first six batches
    to PNG files under ``tmp/`` for visual inspection.

    Args:
        dataset: One of 'REDS', 'Vimeo90K' or 'DIV2K800_sub'. Defaults to
            'DIV2K800_sub'.

    Raises:
        ValueError: if ``dataset`` is not one of the supported names. This
            is raised before the ``tmp`` directory is created.
    """
    opt = _build_options(dataset)
    # The multi-frame datasets deliver their low-quality frames under 'LQs'
    # and are saved one grid per frame index.
    multi_frame = dataset in ('REDS', 'Vimeo90K')

    util.mkdir('tmp')
    train_set = create_dataset(opt)
    train_loader = create_dataloader(train_set, opt, opt, None)

    # Lay the batch out as a roughly square grid.
    nrow = int(math.sqrt(opt['batch_size']))
    padding = 2 if opt['phase'] == 'train' else 0

    print('start...')
    for i, data in enumerate(train_loader):
        if i > 5:
            break
        print(i)

        lq = data['LQs'] if multi_frame else data['lq']
        # NOTE(review): 'hq' is read for every dataset kind; confirm that the
        # multi-frame datasets actually emit that key.
        gt = data['hq']

        if multi_frame:
            # Index 1 of the LQs tensor is iterated as the frame axis.
            for j in range(lq.size(1)):
                torchvision.utils.save_image(
                    lq[:, j, :, :, :], 'tmp/LQ_{:03d}_{}.png'.format(i, j),
                    nrow=nrow, padding=padding, normalize=False)
        else:
            torchvision.utils.save_image(
                lq, 'tmp/LQ_{:03d}.png'.format(i),
                nrow=nrow, padding=padding, normalize=False)
        torchvision.utils.save_image(
            gt, 'tmp/GT_{:03d}.png'.format(i),
            nrow=nrow, padding=padding, normalize=False)
# Run the dataloader check only when this file is executed as a script,
# not when it is imported.
if __name__ == '__main__':
    main()