# DL-Art-School/dlas/data/__init__.py
#
# 138 lines
# 6.1 KiB
# Python

"""create dataset and dataloader"""
import torch
import torch.utils.data
from munch import munchify
from dlas.utils.util import opt_get
def create_dataloader(dataset, dataset_opt, opt=None, sampler=None, collate_fn=None, shuffle=True):
    """Build a ``torch.utils.data.DataLoader`` for ``dataset``.

    Args:
        dataset: The dataset handed to the ``DataLoader``.
        dataset_opt: Dataset options. ``phase`` and ``batch_size`` are always
            read; ``n_workers`` is read in the ``'train'`` phase; ``pin_memory``
            is looked up through ``opt_get`` with a default of ``True``.
        opt: Optional top-level options. ``opt_get(opt, ['dist'], False)``
            decides whether the batch size is split across the distributed
            world size.
        sampler: Optional sampler, only used in the ``'train'`` phase.
        collate_fn: Optional collate function forwarded to the ``DataLoader``.
        shuffle: Whether to shuffle in the ``'train'`` phase. Ignored when a
            ``sampler`` is supplied, because ``DataLoader`` rejects that
            combination.

    Returns:
        A ``torch.utils.data.DataLoader``. Train loaders drop the last
        incomplete batch; loaders for any other phase never shuffle and load
        in the main process.

    Raises:
        AssertionError: In distributed training, if ``batch_size`` is not
            divisible by the world size.
    """
    phase = dataset_opt['phase']
    pin_memory = opt_get(dataset_opt, ['pin_memory'], True)

    if phase != 'train':
        # A falsy batch size (None / 0) falls back to one sample per batch.
        batch_size = dataset_opt['batch_size'] or 1
        return torch.utils.data.DataLoader(
            dataset, batch_size=batch_size, shuffle=False, num_workers=0,
            pin_memory=pin_memory, collate_fn=collate_fn)

    num_workers = dataset_opt['n_workers']
    batch_size = dataset_opt['batch_size']
    if opt_get(opt, ['dist'], False):
        # The configured batch size is the global one; each process gets an
        # equal share of it.
        world_size = torch.distributed.get_world_size()
        assert batch_size % world_size == 0, (
            'batch_size ({}) must be divisible by the world size ({})'.format(
                batch_size, world_size))
        batch_size = batch_size // world_size

    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        # DataLoader raises ValueError when both a sampler and shuffle=True
        # are given, so the sampler takes precedence.
        shuffle=shuffle and sampler is None,
        num_workers=num_workers,
        sampler=sampler,
        drop_last=True,
        pin_memory=pin_memory,
        collate_fn=collate_fn,
        # DataLoader raises ValueError for persistent_workers=True when there
        # are no worker processes.
        persistent_workers=num_workers > 0)
def _with_tacotron_hparams(dataset_opt):
    """Return the Tacotron2 default hparams overridden by ``dataset_opt``, munchified."""
    from dlas.models.audio.tts.tacotron2 import create_hparams
    merged = create_hparams()
    merged.update(dataset_opt)
    return munchify(merged)


def create_dataset(dataset_opt, return_collate=False):
    """Instantiate the dataset selected by ``dataset_opt['mode']``.

    The dataset class for each mode is imported lazily, so only the module
    backing the requested mode is loaded.

    Args:
        dataset_opt: Dataset options; ``mode`` selects the dataset class. The
            options (merged with the Tacotron2 default hparams for the
            Tacotron-based audio modes) are passed to the dataset constructor.
        return_collate: When true, also return the collate object chosen for
            the mode (``None`` when the mode does not define one).

    Returns:
        The dataset, or ``(dataset, collate)`` when ``return_collate`` is true.

    Raises:
        NotImplementedError: If ``mode`` is not one of the known modes.
    """
    mode = dataset_opt['mode']
    collate = None

    # datasets for image restoration
    if mode == 'fullimage':
        from dlas.data.images.full_image_dataset import FullImageDataset as D
    elif mode == 'single_image_extensible':
        from dlas.data.images.single_image_dataset import \
            SingleImageDataset as D
    elif mode == 'multi_frame_extensible':
        from dlas.data.images.multi_frame_dataset import MultiFrameDataset as D
    elif mode == 'combined':
        from dlas.data.combined_dataset import CombinedDataset as D
    elif mode == 'multiscale':
        from dlas.data.images.multiscale_dataset import MultiScaleDataset as D
    elif mode == 'paired_frame':
        from dlas.data.images.paired_frame_dataset import \
            PairedFrameDataset as D
    elif mode == 'stylegan2':
        from dlas.data.images.stylegan2_dataset import Stylegan2Dataset as D
    elif mode == 'imagefolder':
        from dlas.data.images.image_folder_dataset import \
            ImageFolderDataset as D
    elif mode == 'torch_dataset':
        # Uses the same `dlas.` package prefix as every other branch.
        from dlas.data.torch_dataset import TorchDataset as D
    elif mode == 'byol_dataset':
        from dlas.data.images.byol_attachment import ByolDatasetWrapper as D
    elif mode == 'byol_structured_dataset':
        from dlas.data.images.byol_attachment import \
            StructuredCropDatasetWrapper as D
    elif mode == 'random_aug_wrapper':
        from dlas.data.images.byol_attachment import \
            DatasetRandomAugWrapper as D
    elif mode == 'random_dataset':
        from dlas.data.images.random_dataset import RandomDataset as D
    elif mode == 'zipfile':
        from dlas.data.images.zip_file_dataset import ZipFileDataset as D
    # audio datasets
    elif mode == 'nv_tacotron':
        from dlas.data.audio.nv_tacotron_dataset import TextMelCollate as C
        from dlas.data.audio.nv_tacotron_dataset import TextWavLoader as D
        dataset_opt = _with_tacotron_hparams(dataset_opt)
        # Collation is on unless the options explicitly turn it off.
        if opt_get(dataset_opt, ['needs_collate'], True):
            collate = C()
    elif mode == 'paired_voice_audio':
        from dlas.data.audio.paired_voice_audio_dataset import \
            TextWavLoader as D
        dataset_opt = _with_tacotron_hparams(dataset_opt)
    elif mode == 'fast_paired_voice_audio':
        from dlas.data.audio.fast_paired_dataset import \
            FastPairedVoiceDataset as D
        dataset_opt = _with_tacotron_hparams(dataset_opt)
    elif mode == 'fast_paired_voice_audio_with_phonemes':
        from dlas.data.audio.fast_paired_dataset_with_phonemes import \
            FastPairedVoiceDataset as D
        dataset_opt = _with_tacotron_hparams(dataset_opt)
    elif mode == 'gpt_tts':
        from dlas.data.audio.gpt_tts_dataset import GptTtsCollater as C
        from dlas.data.audio.gpt_tts_dataset import GptTtsDataset as D
        collate = C(dataset_opt)
    elif mode == 'unsupervised_audio':
        from dlas.data.audio.unsupervised_audio_dataset import \
            UnsupervisedAudioDataset as D
    elif mode == 'unsupervised_audio_with_noise':
        from dlas.data.audio.audio_with_noise_dataset import \
            AudioWithNoiseDataset as D
    elif mode == 'preprocessed_mel':
        from dlas.data.audio.preprocessed_mel_dataset import \
            PreprocessedMelDataset as D
    elif mode == 'grand_conjoined_voice':
        from dlas.data.audio.grand_conjoined_dataset import \
            GrandConjoinedDataset as D
        from dlas.data.zero_pad_dict_collate import ZeroPadDictCollate as C
        # Collation is off unless the options explicitly ask for it.
        if opt_get(dataset_opt, ['needs_collate'], False):
            collate = C()
    else:
        # '{}' rather than '{:s}': the latter itself raises for a non-string
        # mode (e.g. None), hiding the intended NotImplementedError.
        raise NotImplementedError(
            'Dataset [{}] is not recognized.'.format(mode))

    dataset = D(dataset_opt)
    if return_collate:
        return dataset, collate
    return dataset
def get_dataset_debugger(dataset_opt):
    """Return a debugger object for the dataset mode, if the mode has one.

    Args:
        dataset_opt: Dataset options; only ``mode`` is read.

    Returns:
        A new ``PairedVoiceDebugger`` for ``'paired_voice_audio'``, a new
        ``FastPairedVoiceDebugger`` for ``'fast_paired_voice_audio'``, and
        ``None`` for every other mode.
    """
    requested_mode = dataset_opt['mode']
    debugger = None

    # The debugger classes are imported lazily so that only the module for the
    # requested mode is loaded.
    if requested_mode == 'fast_paired_voice_audio':
        from dlas.data.audio.fast_paired_dataset import \
            FastPairedVoiceDebugger as debugger_cls
        debugger = debugger_cls()
    elif requested_mode == 'paired_voice_audio':
        from dlas.data.audio.paired_voice_audio_dataset import \
            PairedVoiceDebugger as debugger_cls
        debugger = debugger_cls()

    return debugger