"""Create datasets and dataloaders from option dictionaries."""
import logging

import torch
import torch.utils.data


def create_dataloader(dataset, dataset_opt, opt=None, sampler=None):
    """Build a ``torch.utils.data.DataLoader`` for ``dataset``.

    Args:
        dataset: The dataset to load from.
        dataset_opt: Per-dataset options. Always reads ``'phase'``. In the
            ``'train'`` phase it also reads ``'n_workers'`` and
            ``'batch_size'``; in any other phase it reads the optional
            ``'batch_size'`` (falls back to 1 when missing or falsy).
        opt: Global options. Required when ``phase == 'train'``; reads
            ``'dist'`` and, when not distributed, ``'gpu_ids'``.
        sampler: Optional sampler forwarded to the ``DataLoader``. When one
            is given, shuffling is disabled because ``DataLoader`` rejects
            ``shuffle=True`` combined with an explicit sampler.

    Returns:
        A configured ``torch.utils.data.DataLoader``.

    Raises:
        TypeError: If ``opt`` is ``None`` in the ``'train'`` phase.
        ValueError: If, in distributed training, ``dataset_opt['batch_size']``
            is not divisible by the distributed world size.
    """
    phase = dataset_opt['phase']

    if phase != 'train':
        # Evaluation: deterministic order, batch size defaults to 1.
        batch_size = dataset_opt.get('batch_size') or 1
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=max(int(batch_size / 2), 1),
            pin_memory=True,
        )

    if opt is None:
        raise TypeError("create_dataloader() requires 'opt' when phase == 'train'")

    if opt['dist']:
        world_size = torch.distributed.get_world_size()
        num_workers = dataset_opt['n_workers']
        total_batch_size = dataset_opt['batch_size']
        # The configured batch size is split evenly across processes.
        if total_batch_size % world_size != 0:
            raise ValueError(
                'batch_size ({}) must be divisible by the world size ({})'.format(
                    total_batch_size, world_size))
        batch_size = total_batch_size // world_size
        # NOTE(review): presumably the caller passes a distributed sampler
        # that handles per-rank shuffling; confirm against callers.
        shuffle = False
    else:
        # NOTE(review): max() imposes a *floor* of 10 workers; confirm this
        # was not meant to be a cap (min()).
        num_workers = max(dataset_opt['n_workers'] * len(opt['gpu_ids']), 10)
        batch_size = dataset_opt['batch_size']
        # DataLoader forbids shuffle=True together with an explicit sampler.
        shuffle = sampler is None

    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        sampler=sampler,
        drop_last=True,
        pin_memory=True,
    )


def create_dataset(dataset_opt):
    """Instantiate the dataset class selected by ``dataset_opt['mode']``.

    Supported modes: ``'LQ'``, ``'LQGT'`` (image restoration),
    ``'downsample'`` and ``'fullimage'`` (image corruption). The dataset
    class is imported lazily and constructed with ``dataset_opt``.

    Args:
        dataset_opt: Dataset options; reads ``'mode'`` and ``'name'`` and is
            passed whole to the dataset constructor.

    Returns:
        The constructed dataset instance.

    Raises:
        NotImplementedError: If ``mode`` is not one of the supported modes.
    """
    mode = dataset_opt['mode']
    # datasets for image restoration
    if mode == 'LQ':
        from data.LQ_dataset import LQDataset as D
    elif mode == 'LQGT':
        from data.LQGT_dataset import LQGTDataset as D
    # datasets for image corruption
    elif mode == 'downsample':
        from data.Downsample_dataset import DownsampleDataset as D
    elif mode == 'fullimage':
        from data.full_image_dataset import FullImageDataset as D
    else:
        # '{}' instead of '{:s}' so a non-str mode (e.g. None) still yields
        # this error rather than a TypeError from the format call.
        raise NotImplementedError('Dataset [{}] is not recognized.'.format(mode))
    dataset = D(dataset_opt)

    logger = logging.getLogger('base')
    logger.info('Dataset [%s - %s] is created.',
                dataset.__class__.__name__, dataset_opt['name'])
    return dataset