forked from mrq/DL-Art-School
24792bdb4f
Removed a lot of legacy stuff I have no intention of using again. The plan is to shape this repo into something more extensible (get it? hah!)
105 lines
3.6 KiB
Python
105 lines
3.6 KiB
Python
import sys
|
|
import os.path as osp
|
|
import math
|
|
import torchvision.utils
|
|
|
|
sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__))))
|
|
from data import create_dataloader, create_dataset # noqa: E402
|
|
from utils import util # noqa: E402
|
|
|
|
|
|
def _video_options(name, dataroot_GT, dataroot_LQ, mode, n_frames):
    """Return the options for a multi-frame dataset (REDS or Vimeo90K).

    Only the name, the two lmdb roots, the mode and the frame count differ
    between those datasets; every other value is shared.
    """
    return {
        'name': name,
        'dataroot_GT': dataroot_GT,
        'dataroot_LQ': dataroot_LQ,
        'mode': mode,
        'N_frames': n_frames,
        'phase': 'train',
        'use_shuffle': True,
        'n_workers': 8,
        'batch_size': 16,
        'target_size': 256,
        'LQ_size': 64,
        'scale': 4,
        'use_flip': True,
        'use_rot': True,
        'interval_list': [1],
        'random_reverse': False,
        'border_mode': False,
        'cache_keys': None,
        'data_type': 'lmdb',  # img | lmdb | mc
    }


def _build_options(dataset):
    """Return the option dict passed to the dataset and dataloader factories.

    Raises:
        ValueError: if ``dataset`` is not 'REDS', 'Vimeo90K' or
            'DIV2K800_sub'.
    """
    opt = {'dist': False, 'gpu_ids': [0]}
    if dataset == 'REDS':
        opt.update(_video_options(
            name='test_REDS',
            dataroot_GT='../../datasets/REDS/train_sharp_wval.lmdb',
            dataroot_LQ='../../datasets/REDS/train_sharp_bicubic_wval.lmdb',
            mode='REDS',
            n_frames=5,
        ))
    elif dataset == 'Vimeo90K':
        opt.update(_video_options(
            name='test_Vimeo90K',
            dataroot_GT='../../datasets/vimeo90k/vimeo90k_train_GT.lmdb',
            dataroot_LQ='../../datasets/vimeo90k/vimeo90k_train_LR7frames.lmdb',
            mode='Vimeo90K',
            n_frames=7,
        ))
    elif dataset == 'DIV2K800_sub':
        opt.update({
            'name': 'DIV2K800',
            'dataroot_GT': '../../datasets/DIV2K/DIV2K800_sub.lmdb',
            'dataroot_LQ': '../../datasets/DIV2K/DIV2K800_sub_bicLRx4.lmdb',
            'mode': 'LQGT',
            'phase': 'train',
            'use_shuffle': True,
            'n_workers': 8,
            'batch_size': 16,
            'target_size': 128,
            'scale': 4,
            'use_flip': True,
            'use_rot': True,
            'color': 'RGB',
            'data_type': 'lmdb',  # img | lmdb
        })
    else:
        raise ValueError('Please implement by yourself.')
    return opt


def main(dataset='DIV2K800_sub'):
    """Dump the first few batches of a dataloader as image grids under tmp/.

    Builds the options for ``dataset``, creates the dataset and dataloader
    through the project factories, and writes the first six batches to PNG
    files so the loader output can be inspected by eye.

    Args:
        dataset: one of 'REDS', 'Vimeo90K' or 'DIV2K800_sub'. Defaults to
            'DIV2K800_sub', the value that used to be hard-coded here.

    Raises:
        ValueError: if ``dataset`` is not one of the supported names.
    """
    # Built first so an unsupported name fails before anything is created.
    opt = _build_options(dataset)

    util.mkdir('tmp')
    train_set = create_dataset(opt)
    train_loader = create_dataloader(train_set, opt, opt, None)

    # Number of images per grid row, derived from the batch size.
    nrow = int(math.sqrt(opt['batch_size']))
    padding = 2 if opt['phase'] == 'train' else 0
    # Loop-invariant: the multi-frame datasets expose 'LQs', the others 'LQ'.
    multi_frame = dataset in ('REDS', 'Vimeo90K')

    print('start...')
    for i, data in enumerate(train_loader):
        if i > 5:
            break
        print(i)

        if multi_frame:
            # One grid per index along dimension 1 of the 5-D 'LQs' tensor.
            LQs = data['LQs']
            outputs = [
                (LQs[:, j, :, :, :], 'tmp/LQ_{:03d}_{}.png'.format(i, j))
                for j in range(LQs.size(1))
            ]
        else:
            outputs = [(data['LQ'], 'tmp/LQ_{:03d}.png'.format(i))]
        outputs.append((data['GT'], 'tmp/GT_{:03d}.png'.format(i)))

        for tensor, path in outputs:
            torchvision.utils.save_image(tensor, path, nrow=nrow,
                                         padding=padding, normalize=False)
|
|
|
|
|
|
# Run the dataloader check only when this file is executed as a script,
# not when it is imported.
if __name__ == '__main__':
    main()
|