forked from mrq/DL-Art-School
67139602f5
Allows bifurcating large images put into the test pipeline. This code is fixed (not dynamic) and needs some fixes.
54 lines
2.2 KiB
Python
54 lines
2.2 KiB
Python
"""create dataset and dataloader"""
|
|
import logging
|
|
import torch
|
|
import torch.utils.data
|
|
|
|
|
|
def create_dataloader(dataset, dataset_opt, opt=None, sampler=None):
    """Build a ``torch.utils.data.DataLoader`` for ``dataset``.

    Args:
        dataset: The dataset to wrap.
        dataset_opt: Per-dataset options. ``'phase'`` is always read.
            - In the ``'train'`` phase, ``'n_workers'`` and ``'batch_size'`` are
              also required.
            - In any other phase, ``'batch_size'`` is optional; a missing, None
              or 0 value falls back to 1.
        opt: Global options. Required in the ``'train'`` phase, where
            ``'dist'`` is read, plus ``'gpu_ids'`` when not distributed.
        sampler: Optional sampler forwarded to the training loader. When one
            is given, the loader does not shuffle on its own, because
            ``DataLoader`` rejects ``shuffle=True`` together with a sampler.

    Returns:
        A configured ``DataLoader``. Training loaders drop the last incomplete
        batch; evaluation loaders keep every sample in order.

    Raises:
        ValueError: In distributed training, if ``dataset_opt['batch_size']``
            is not divisible by the distributed world size.
    """
    phase = dataset_opt['phase']

    if phase != 'train':
        # Evaluation / test: deterministic order, no dropped samples.
        # ``.get`` so that an absent batch size falls back to 1 instead of
        # raising KeyError on a plain dict.
        batch_size = dataset_opt.get('batch_size') or 1
        return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False,
                                           num_workers=max(int(batch_size / 2), 1),
                                           pin_memory=False)

    if opt['dist']:
        world_size = torch.distributed.get_world_size()
        num_workers = dataset_opt['n_workers']
        total_batch_size = dataset_opt['batch_size']
        if total_batch_size % world_size != 0:
            raise ValueError(
                'Batch size {} must be divisible by the world size {}.'.format(
                    total_batch_size, world_size))
        # The configured batch size is global; each rank gets an equal share.
        batch_size = total_batch_size // world_size
        # Ordering is left to the supplied sampler in distributed mode.
        shuffle = False
    else:
        # NOTE(review): ``max`` forces at least 10 workers regardless of the
        # configured count; confirm this floor (vs. a ``min`` cap) is intended.
        num_workers = max(dataset_opt['n_workers'] * len(opt['gpu_ids']), 10)
        batch_size = dataset_opt['batch_size']
        # DataLoader forbids shuffle=True together with an explicit sampler.
        shuffle = sampler is None

    return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
                                       num_workers=num_workers, sampler=sampler, drop_last=True,
                                       pin_memory=False)
|
|
|
|
|
|
def create_dataset(dataset_opt):
    """Instantiate the dataset class selected by ``dataset_opt['mode']``.

    The dataset module is imported lazily, so only the selected dataset's
    dependencies are loaded. The chosen class is constructed with
    ``dataset_opt`` as its only argument.

    Args:
        dataset_opt: Dataset options. ``'mode'`` selects the dataset class.
            ``'name'`` is optional and is only used in the log message.

    Returns:
        The constructed dataset instance.

    Raises:
        NotImplementedError: If ``'mode'`` is not one of the supported modes,
            including when it is not a string.
    """
    import importlib

    # mode -> (module path, class name)
    registry = {
        # datasets for image restoration
        'LQ': ('data.LQ_dataset', 'LQDataset'),
        'LQGT': ('data.LQGT_dataset', 'LQGTDataset'),
        # datasets for image corruption
        'downsample': ('data.Downsample_dataset', 'DownsampleDataset'),
        # datasets for video restoration
        'REDS': ('data.REDS_dataset', 'REDSDataset'),
        'Vimeo90K': ('data.Vimeo90K_dataset', 'Vimeo90KDataset'),
        'video_test': ('data.video_test_dataset', 'VideoTestDataset'),
    }

    mode = dataset_opt['mode']
    # Only string modes can match; the isinstance check also keeps unhashable
    # values away from the dict lookup.
    entry = registry.get(mode) if isinstance(mode, str) else None
    if entry is None:
        # '{}' rather than '{:s}' so a non-str mode still produces this error
        # instead of a formatting exception.
        raise NotImplementedError('Dataset [{}] is not recognized.'.format(mode))

    module_name, class_name = entry
    dataset_cls = getattr(importlib.import_module(module_name), class_name)
    dataset = dataset_cls(dataset_opt)

    logger = logging.getLogger('base')
    # Lazy %-formatting: a missing or None name must not crash after the
    # dataset has already been built.
    logger.info('Dataset [%s - %s] is created.', dataset.__class__.__name__,
                dataset_opt.get('name'))
    return dataset
|