2019-08-23 13:42:47 +00:00
|
|
|
"""create dataset and dataloader"""
|
|
|
|
import logging
|
|
|
|
import torch
|
|
|
|
import torch.utils.data
|
2021-07-06 17:11:35 +00:00
|
|
|
from munch import munchify
|
2019-08-23 13:42:47 +00:00
|
|
|
|
2021-06-11 21:31:10 +00:00
|
|
|
from utils.util import opt_get
|
|
|
|
|
2019-08-23 13:42:47 +00:00
|
|
|
|
2021-08-17 04:52:15 +00:00
|
|
|
def create_dataloader(dataset, dataset_opt, opt=None, sampler=None, collate_fn=None, shuffle=True):
    """Build a ``torch.utils.data.DataLoader`` for ``dataset``.

    Args:
        dataset: The dataset to load batches from.
        dataset_opt: Per-dataset options. Always reads ``'phase'``.
            - In the ``'train'`` phase, ``'n_workers'`` and ``'batch_size'`` are
              required.
            - In any other phase, ``'batch_size'`` is optional and falls back to 1
              when it is missing or falsy.
        opt: Global options. Only the ``['dist']`` flag is consulted (via ``opt_get``).
        sampler: Optional sampler, used only in the ``'train'`` phase. When one is
            given, shuffling is turned off, because DataLoader refuses
            ``shuffle=True`` together with a custom sampler.
        collate_fn: Optional function used to collate samples into a batch.
        shuffle: Whether to shuffle training data when no sampler is supplied.

    Returns:
        A configured ``torch.utils.data.DataLoader``.

    Raises:
        ValueError: in distributed training, if the configured batch size is not
            divisible by the world size.
    """
    phase = dataset_opt['phase']
    if phase == 'train':
        num_workers = dataset_opt['n_workers']
        batch_size = dataset_opt['batch_size']
        if opt_get(opt, ['dist'], False):
            # The configured batch size is global; split it evenly across processes.
            world_size = torch.distributed.get_world_size()
            if batch_size % world_size != 0:
                raise ValueError(
                    'batch_size ({}) must be divisible by the world size ({}) in distributed training.'.format(
                        batch_size, world_size))
            batch_size //= world_size
        if sampler is not None:
            # A sampler decides the ordering itself; DataLoader raises ValueError if
            # shuffle=True is combined with a sampler.
            shuffle = False
        return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
                                           num_workers=num_workers, sampler=sampler, drop_last=True,
                                           pin_memory=True, collate_fn=collate_fn)

    # Non-training phases: fixed order, loaded in the main process (num_workers=0).
    batch_size = dataset_opt.get('batch_size') or 1
    return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=0,
                                       pin_memory=True, collate_fn=collate_fn)
|
2019-08-23 13:42:47 +00:00
|
|
|
|
|
|
|
|
2021-07-06 17:11:35 +00:00
|
|
|
# Dataset modes that only require importing their class: mode -> (module path, class name).
# Classes are imported lazily so that selecting one mode never pulls in the
# dependencies of the others.
_SIMPLE_DATASET_MODES = {
    # Datasets for image restoration.
    'fullimage': ('data.full_image_dataset', 'FullImageDataset'),
    'single_image_extensible': ('data.single_image_dataset', 'SingleImageDataset'),
    'multi_frame_extensible': ('data.multi_frame_dataset', 'MultiFrameDataset'),
    'combined': ('data.combined_dataset', 'CombinedDataset'),
    'multiscale': ('data.multiscale_dataset', 'MultiScaleDataset'),
    'paired_frame': ('data.paired_frame_dataset', 'PairedFrameDataset'),
    'stylegan2': ('data.stylegan2_dataset', 'Stylegan2Dataset'),
    'imagefolder': ('data.image_folder_dataset', 'ImageFolderDataset'),
    'torch_dataset': ('data.torch_dataset', 'TorchDataset'),
    'byol_dataset': ('data.byol_attachment', 'ByolDatasetWrapper'),
    'byol_structured_dataset': ('data.byol_attachment', 'StructuredCropDatasetWrapper'),
    'random_aug_wrapper': ('data.byol_attachment', 'DatasetRandomAugWrapper'),
    'random_dataset': ('data.random_dataset', 'RandomDataset'),
    'zipfile': ('data.zip_file_dataset', 'ZipFileDataset'),
    # Audio datasets without extra configuration.
    'unsupervised_audio': ('data.audio.unsupervised_audio_dataset', 'UnsupervisedAudioDataset'),
    'unsupervised_audio_with_noise': ('data.audio.audio_with_noise_dataset', 'AudioWithNoiseDataset'),
}


def create_dataset(dataset_opt, return_collate=False):
    """Instantiate the dataset selected by ``dataset_opt['mode']``.

    Args:
        dataset_opt: Dataset options. ``'mode'`` picks the dataset class, and the
            options are passed to that class's constructor. For the
            'nv_tacotron' and 'paired_voice_audio' modes, the options are first
            merged over ``create_hparams()`` defaults and munchified.
        return_collate: If True, also return the collate function chosen for the mode.

    Returns:
        The dataset instance, or ``(dataset, collate)`` when ``return_collate`` is
        True. ``collate`` is None for modes that do not set one.

    Raises:
        KeyError: if ``dataset_opt`` has no ``'mode'`` entry.
        NotImplementedError: if ``mode`` is not a recognized dataset mode.
    """
    import importlib

    mode = dataset_opt['mode']
    collate = None

    if mode in _SIMPLE_DATASET_MODES:
        module_path, class_name = _SIMPLE_DATASET_MODES[mode]
        dataset_cls = getattr(importlib.import_module(module_path), class_name)
    elif mode == 'nv_tacotron':
        from data.audio.nv_tacotron_dataset import TextWavLoader as dataset_cls
        from data.audio.nv_tacotron_dataset import TextMelCollate
        from models.tacotron2.hparams import create_hparams
        # Start from the tacotron2 default hyperparameters; dataset options override them.
        default_params = create_hparams()
        default_params.update(dataset_opt)
        dataset_opt = munchify(default_params)
        # This mode collates by default unless the options opt out.
        if opt_get(dataset_opt, ['needs_collate'], True):
            collate = TextMelCollate()
    elif mode == 'paired_voice_audio':
        from data.audio.paired_voice_audio_dataset import TextWavLoader as dataset_cls
        from models.tacotron2.hparams import create_hparams
        default_params = create_hparams()
        default_params.update(dataset_opt)
        dataset_opt = munchify(default_params)
    elif mode == 'gpt_tts':
        from data.audio.gpt_tts_dataset import GptTtsDataset as dataset_cls
        from data.audio.gpt_tts_dataset import GptTtsCollater
        collate = GptTtsCollater(dataset_opt)
    elif mode == 'grand_conjoined_voice':
        from data.audio.grand_conjoined_dataset import GrandConjoinedDataset as dataset_cls
        from data.zero_pad_dict_collate import ZeroPadDictCollate
        # Unlike 'nv_tacotron', collation here is opt-in.
        if opt_get(dataset_opt, ['needs_collate'], False):
            collate = ZeroPadDictCollate()
    else:
        # '{}' rather than '{:s}': the latter raises TypeError/ValueError for a
        # non-string mode (e.g. None from an empty YAML value), masking this error.
        raise NotImplementedError('Dataset [{}] is not recognized.'.format(mode))

    dataset = dataset_cls(dataset_opt)
    if return_collate:
        return dataset, collate
    return dataset
|
2022-01-06 19:38:20 +00:00
|
|
|
|
|
|
|
|
|
|
|
def get_dataset_debugger(dataset_opt):
    """Return a debugging helper for the configured dataset mode, if one exists.

    Only the 'paired_voice_audio' mode provides a debugger (a fresh
    ``PairedVoiceDebugger``). Every other mode yields None. A missing
    ``'mode'`` key raises KeyError.
    """
    if dataset_opt['mode'] != 'paired_voice_audio':
        return None
    # Imported lazily so other modes never load the audio dataset module.
    from data.audio.paired_voice_audio_dataset import PairedVoiceDebugger
    return PairedVoiceDebugger()
|