forked from mrq/DL-Art-School
Oh yeah
This commit is contained in:
parent
4c01d82265
commit
d7f30232c3
|
@ -7,7 +7,7 @@ from munch import munchify
|
||||||
from utils.util import opt_get
|
from utils.util import opt_get
|
||||||
|
|
||||||
|
|
||||||
def create_dataloader(dataset, dataset_opt, opt=None, sampler=None, collate_fn=None):
|
def create_dataloader(dataset, dataset_opt, opt=None, sampler=None, collate_fn=None, shuffle=True):
|
||||||
phase = dataset_opt['phase']
|
phase = dataset_opt['phase']
|
||||||
if phase == 'train':
|
if phase == 'train':
|
||||||
if opt_get(opt, ['dist'], False):
|
if opt_get(opt, ['dist'], False):
|
||||||
|
@ -15,11 +15,9 @@ def create_dataloader(dataset, dataset_opt, opt=None, sampler=None, collate_fn=N
|
||||||
num_workers = dataset_opt['n_workers']
|
num_workers = dataset_opt['n_workers']
|
||||||
assert dataset_opt['batch_size'] % world_size == 0
|
assert dataset_opt['batch_size'] % world_size == 0
|
||||||
batch_size = dataset_opt['batch_size'] // world_size
|
batch_size = dataset_opt['batch_size'] // world_size
|
||||||
shuffle = False
|
|
||||||
else:
|
else:
|
||||||
num_workers = dataset_opt['n_workers']
|
num_workers = dataset_opt['n_workers']
|
||||||
batch_size = dataset_opt['batch_size']
|
batch_size = dataset_opt['batch_size']
|
||||||
shuffle = True
|
|
||||||
return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
|
return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
|
||||||
num_workers=num_workers, sampler=sampler, drop_last=True,
|
num_workers=num_workers, sampler=sampler, drop_last=True,
|
||||||
pin_memory=True, collate_fn=collate_fn)
|
pin_memory=True, collate_fn=collate_fn)
|
||||||
|
@ -77,6 +75,12 @@ def create_dataset(dataset_opt, return_collate=False):
|
||||||
collate = C(dataset_opt)
|
collate = C(dataset_opt)
|
||||||
elif mode == 'wavfile_clips':
|
elif mode == 'wavfile_clips':
|
||||||
from data.audio.wavfile_dataset import WavfileDataset as D
|
from data.audio.wavfile_dataset import WavfileDataset as D
|
||||||
|
elif mode == 'stop_prediction':
|
||||||
|
from models.tacotron2.hparams import create_hparams
|
||||||
|
default_params = create_hparams()
|
||||||
|
default_params.update(dataset_opt)
|
||||||
|
dataset_opt = munchify(default_params)
|
||||||
|
from data.audio.stop_prediction_dataset import StopPredictionDataset as D
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError('Dataset [{:s}] is not recognized.'.format(mode))
|
raise NotImplementedError('Dataset [{:s}] is not recognized.'.format(mode))
|
||||||
dataset = D(dataset_opt)
|
dataset = D(dataset_opt)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user