"""create dataset and dataloader"""
import logging
import torch
import torch.utils.data


def create_dataloader(dataset, dataset_opt, opt=None, sampler=None):
    """Build a ``torch.utils.data.DataLoader`` for ``dataset``.

    Args:
        dataset: Dataset to wrap.
        dataset_opt: Per-dataset options. ``'phase'`` selects training vs.
            evaluation behaviour. ``'batch_size'`` is always read; in the
            ``'train'`` phase ``'n_workers'`` is read too.
        opt: Global options. Required in the ``'train'`` phase, where
            ``'dist'`` is read, and ``'gpu_ids'`` when not distributed.
        sampler: Optional sampler forwarded to the training DataLoader. When
            given, the loader does not shuffle on its own, because DataLoader
            forbids combining ``shuffle=True`` with an explicit sampler.

    Returns:
        A configured ``torch.utils.data.DataLoader``.

    Raises:
        ValueError: If ``opt`` is missing in the train phase, or if the
            batch size is not divisible by the distributed world size.
    """
    phase = dataset_opt['phase']
    if phase != 'train':
        # Evaluation: single-process loading, no shuffling. A missing or falsy
        # batch size falls back to 1.
        batch_size = dataset_opt.get('batch_size') or 1
        return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=0,
                                           pin_memory=True)

    if opt is None:
        raise ValueError("create_dataloader requires 'opt' for the 'train' phase.")

    if opt['dist']:
        world_size = torch.distributed.get_world_size()
        num_workers = dataset_opt['n_workers']
        # The global batch size is split evenly across processes.
        if dataset_opt['batch_size'] % world_size != 0:
            raise ValueError('batch_size ({}) must be divisible by world size ({}).'.format(
                dataset_opt['batch_size'], world_size))
        batch_size = dataset_opt['batch_size'] // world_size
        # Never shuffle here. Presumably the caller supplies a (distributed)
        # sampler that handles shuffling; TODO confirm.
        shuffle = False
    else:
        # Worker count scales with the number of GPUs in use.
        num_workers = dataset_opt['n_workers'] * len(opt['gpu_ids'])
        batch_size = dataset_opt['batch_size']
        # shuffle=True together with a sampler makes DataLoader raise.
        shuffle = sampler is None
    return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
                                       num_workers=num_workers, sampler=sampler, drop_last=True,
                                       pin_memory=True)


def create_dataset(dataset_opt):
    mode = dataset_opt['mode']
    # datasets for image restoration
    if mode == 'fullimage':
        from data.full_image_dataset import FullImageDataset as D
    elif mode == 'single_image_extensible':
        from data.single_image_dataset import SingleImageDataset as D
    elif mode == 'multi_frame_extensible':
        from data.multi_frame_dataset import MultiFrameDataset as D
    elif mode == 'combined':
        from data.combined_dataset import CombinedDataset as D
    elif mode == 'multiscale':
        from data.multiscale_dataset import MultiScaleDataset as D
    elif mode == 'paired_frame':
        from data.paired_frame_dataset import PairedFrameDataset as D
    elif mode == 'stylegan2':
        from data.stylegan2_dataset import Stylegan2Dataset as D
    elif mode == 'imagefolder':
        from data.image_folder_dataset import ImageFolderDataset as D
    elif mode == 'torch_dataset':
        from data.torch_dataset import TorchDataset as D
    elif mode == 'byol_dataset':
        from data.byol_attachment import ByolDatasetWrapper as D
    elif mode == 'byol_structured_dataset':
        from data.byol_attachment import StructuredCropDatasetWrapper as D
    elif mode == 'random_aug_wrapper':
        from data.byol_attachment import DatasetRandomAugWrapper as D
    elif mode == 'random_dataset':
        from data.random_dataset import RandomDataset as D
    else:
        raise NotImplementedError('Dataset [{:s}] is not recognized.'.format(mode))
    dataset = D(dataset_opt)

    return dataset