forked from mrq/DL-Art-School
11155aead4
This has been a long time coming. Cleans up messy "GT" nomenclature and simplifies ExtensibleTrainer.feed_data
105 lines
3.6 KiB
Python
105 lines
3.6 KiB
Python
import sys
|
|
import os.path as osp
|
|
import math
|
|
import torchvision.utils
|
|
|
|
sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__))))
|
|
from data import create_dataloader, create_dataset # noqa: E402
|
|
from utils import util # noqa: E402
|
|
|
|
|
|
def main():
    """Smoke-test a dataloader by dumping its first batches as image grids.

    Builds a hard-coded option dict for the dataset selected below, creates
    the dataset and its loader, then writes the low-quality and ground-truth
    tensors of the first six batches to PNG grids under ``tmp/``.

    Raises:
        ValueError: if ``dataset`` is not one of the configured names.
    """
    dataset = 'DIV2K800_sub'  # REDS | Vimeo90K | DIV2K800_sub
    # The two video datasets deliver a stack of LQ frames under 'LQs'.
    is_video = dataset in ('REDS', 'Vimeo90K')

    opt = {'dist': False, 'gpu_ids': [0]}
    if dataset == 'REDS':
        opt.update({
            'name': 'test_REDS',
            'dataroot_GT': '../../datasets/REDS/train_sharp_wval.lmdb',
            'dataroot_LQ': '../../datasets/REDS/train_sharp_bicubic_wval.lmdb',
            'mode': 'REDS',
            'N_frames': 5,
            'phase': 'train',
            'use_shuffle': True,
            'n_workers': 8,
            'batch_size': 16,
            'target_size': 256,
            'LQ_size': 64,
            'scale': 4,
            'use_flip': True,
            'use_rot': True,
            'interval_list': [1],
            'random_reverse': False,
            'border_mode': False,
            'cache_keys': None,
            'data_type': 'lmdb',  # img | lmdb | mc
        })
    elif dataset == 'Vimeo90K':
        opt.update({
            'name': 'test_Vimeo90K',
            'dataroot_GT': '../../datasets/vimeo90k/vimeo90k_train_GT.lmdb',
            'dataroot_LQ': '../../datasets/vimeo90k/vimeo90k_train_LR7frames.lmdb',
            'mode': 'Vimeo90K',
            'N_frames': 7,
            'phase': 'train',
            'use_shuffle': True,
            'n_workers': 8,
            'batch_size': 16,
            'target_size': 256,
            'LQ_size': 64,
            'scale': 4,
            'use_flip': True,
            'use_rot': True,
            'interval_list': [1],
            'random_reverse': False,
            'border_mode': False,
            'cache_keys': None,
            'data_type': 'lmdb',  # img | lmdb | mc
        })
    elif dataset == 'DIV2K800_sub':
        opt.update({
            'name': 'DIV2K800',
            'dataroot_GT': '../../datasets/DIV2K/DIV2K800_sub.lmdb',
            'dataroot_LQ': '../../datasets/DIV2K/DIV2K800_sub_bicLRx4.lmdb',
            'mode': 'LQGT',
            'phase': 'train',
            'use_shuffle': True,
            'n_workers': 8,
            'batch_size': 16,
            'target_size': 128,
            'scale': 4,
            'use_flip': True,
            'use_rot': True,
            'color': 'RGB',
            'data_type': 'lmdb',  # img | lmdb
        })
    else:
        raise ValueError('Please implement by yourself.')

    util.mkdir('tmp')
    train_set = create_dataset(opt)
    train_loader = create_dataloader(train_set, opt, opt, None)

    # Lay each batch out as a roughly square grid; pad tiles only in training.
    nrow = int(math.sqrt(opt['batch_size']))
    if opt['phase'] == 'train':
        padding = 2
    else:
        padding = 0

    def save_grid(tensor, path):
        # All dumps share the same grid layout and skip normalization.
        torchvision.utils.save_image(tensor, path, nrow=nrow, padding=padding,
                                     normalize=False)

    print('start...')
    for step, batch in enumerate(train_loader):
        # Only the first six batches (indices 0..5) are dumped.
        if step > 5:
            break
        print(step)

        # Fetch both tensors up front so a missing key fails before any
        # image of this batch is written.
        lq_batch = batch['LQs'] if is_video else batch['lq']
        gt_batch = batch['hq']

        if is_video:
            # One grid per index along dimension 1 of the LQ stack.
            for frame in range(lq_batch.size(1)):
                save_grid(lq_batch[:, frame, :, :, :],
                          'tmp/LQ_{:03d}_{}.png'.format(step, frame))
        else:
            save_grid(lq_batch, 'tmp/LQ_{:03d}.png'.format(step))
        save_grid(gt_batch, 'tmp/GT_{:03d}.png'.format(step))
|
|
|
|
|
|
# Run the dataloader smoke test only when executed as a script, not on import.
if __name__ == '__main__':
    main()
|