DL-Art-School/codes/data/__init__.py

63 lines
2.7 KiB
Python
Raw Normal View History

2019-08-23 13:42:47 +00:00
"""create dataset and dataloader"""
import logging
import torch
import torch.utils.data
def create_dataloader(dataset, dataset_opt, opt=None, sampler=None):
    """Build a ``torch.utils.data.DataLoader`` for ``dataset``.

    Args:
        dataset: The dataset to wrap.
        dataset_opt: Per-dataset options. ``'phase'`` is always read. For the
            ``'train'`` phase ``'n_workers'`` and ``'batch_size'`` are read; for
            any other phase ``'batch_size'`` is optional and a missing or falsy
            value means 1.
        opt: Global options. Required for the ``'train'`` phase, where
            ``'dist'`` is read and, when not distributed, ``'gpu_ids'``.
        sampler: Optional sampler forwarded to the DataLoader in the
            ``'train'`` phase (ignored for other phases). When a sampler is
            given, DataLoader-level shuffling is disabled.

    Returns:
        A ``torch.utils.data.DataLoader``. Train loaders drop the last
        incomplete batch; non-train loaders never shuffle and use no worker
        processes.

    Raises:
        ValueError: if ``opt`` is missing for the ``'train'`` phase, or if the
            configured batch size is not divisible by the distributed world size.
    """
    phase = dataset_opt['phase']
    if phase == 'train':
        if opt is None:
            raise ValueError("create_dataloader requires `opt` for the 'train' phase.")
        if opt['dist']:
            world_size = torch.distributed.get_world_size()
            num_workers = dataset_opt['n_workers']
            # The configured batch size is split evenly across all processes.
            if dataset_opt['batch_size'] % world_size != 0:
                raise ValueError(
                    'batch_size ({}) must be divisible by the distributed world size ({}).'.format(
                        dataset_opt['batch_size'], world_size))
            batch_size = dataset_opt['batch_size'] // world_size
            # No DataLoader-level shuffling here; ordering is left to `sampler` (if any).
            shuffle = False
        else:
            # n_workers is multiplied by the number of entries in gpu_ids.
            num_workers = dataset_opt['n_workers'] * len(opt['gpu_ids'])
            batch_size = dataset_opt['batch_size']
            # DataLoader rejects shuffle=True combined with an explicit sampler.
            shuffle = sampler is None
        return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
                                           num_workers=num_workers, sampler=sampler, drop_last=True,
                                           pin_memory=True)
    # Any non-train phase: deterministic order, loading in the main process.
    batch_size = dataset_opt.get('batch_size') or 1
    return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=0,
                                       pin_memory=True)
# Supported dataset modes: mode -> (module path, class name). Modules are
# imported lazily by create_dataset, only for the mode that is selected.
_DATASET_CLASSES = {
    'fullimage': ('data.full_image_dataset', 'FullImageDataset'),
    'single_image_extensible': ('data.single_image_dataset', 'SingleImageDataset'),
    'multi_frame_extensible': ('data.multi_frame_dataset', 'MultiFrameDataset'),
    'combined': ('data.combined_dataset', 'CombinedDataset'),
    'multiscale': ('data.multiscale_dataset', 'MultiScaleDataset'),
    'paired_frame': ('data.paired_frame_dataset', 'PairedFrameDataset'),
    'stylegan2': ('data.stylegan2_dataset', 'Stylegan2Dataset'),
    'imagefolder': ('data.image_folder_dataset', 'ImageFolderDataset'),
    'torch_dataset': ('data.torch_dataset', 'TorchDataset'),
    'byol_dataset': ('data.byol_attachment', 'ByolDatasetWrapper'),
    'byol_structured_dataset': ('data.byol_attachment', 'StructuredCropDatasetWrapper'),
    'random_aug_wrapper': ('data.byol_attachment', 'DatasetRandomAugWrapper'),
    'random_dataset': ('data.random_dataset', 'RandomDataset'),
}


def create_dataset(dataset_opt):
    """Instantiate the dataset class selected by ``dataset_opt['mode']``.

    The class is looked up in ``_DATASET_CLASSES``, its module is imported on
    demand, and the class is constructed with ``dataset_opt`` as its only
    argument.

    Args:
        dataset_opt: Dataset options; ``'mode'`` selects the dataset class and
            the whole mapping is passed to that class's constructor.

    Returns:
        The constructed dataset instance.

    Raises:
        KeyError: if ``dataset_opt`` has no ``'mode'`` entry.
        NotImplementedError: if ``mode`` is not a recognized dataset mode
            (including non-string values such as ``None``).
    """
    import importlib

    mode = dataset_opt['mode']
    # Guard non-string modes so they reach the NotImplementedError below
    # instead of failing in the dict lookup or the message formatting.
    target = _DATASET_CLASSES.get(mode) if isinstance(mode, str) else None
    if target is None:
        raise NotImplementedError('Dataset [{}] is not recognized.'.format(mode))
    module_name, class_name = target
    dataset_cls = getattr(importlib.import_module(module_name), class_name)
    return dataset_cls(dataset_opt)