From d7f30232c363c2deea6dca58c6d88c6ddbb61525 Mon Sep 17 00:00:00 2001 From: James Betker Date: Mon, 16 Aug 2021 22:52:15 -0600 Subject: [PATCH] Oh yeah --- codes/data/__init__.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/codes/data/__init__.py b/codes/data/__init__.py index 4f0c0c7c..aa74d6e6 100644 --- a/codes/data/__init__.py +++ b/codes/data/__init__.py @@ -7,7 +7,7 @@ from munch import munchify from utils.util import opt_get -def create_dataloader(dataset, dataset_opt, opt=None, sampler=None, collate_fn=None): +def create_dataloader(dataset, dataset_opt, opt=None, sampler=None, collate_fn=None, shuffle=True): phase = dataset_opt['phase'] if phase == 'train': if opt_get(opt, ['dist'], False): @@ -15,11 +15,9 @@ def create_dataloader(dataset, dataset_opt, opt=None, sampler=None, collate_fn=N num_workers = dataset_opt['n_workers'] assert dataset_opt['batch_size'] % world_size == 0 batch_size = dataset_opt['batch_size'] // world_size - shuffle = False else: num_workers = dataset_opt['n_workers'] batch_size = dataset_opt['batch_size'] - shuffle = True return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, sampler=sampler, drop_last=True, pin_memory=True, collate_fn=collate_fn) @@ -77,6 +75,12 @@ def create_dataset(dataset_opt, return_collate=False): collate = C(dataset_opt) elif mode == 'wavfile_clips': from data.audio.wavfile_dataset import WavfileDataset as D + elif mode == 'stop_prediction': + from models.tacotron2.hparams import create_hparams + default_params = create_hparams() + default_params.update(dataset_opt) + dataset_opt = munchify(default_params) + from data.audio.stop_prediction_dataset import StopPredictionDataset as D else: raise NotImplementedError('Dataset [{:s}] is not recognized.'.format(mode)) dataset = D(dataset_opt)