From 277c759ab13fdce4c80c9c2f2c9d42def8698aac Mon Sep 17 00:00:00 2001
From: mrq
Date: Mon, 14 Aug 2023 21:42:35 -0500
Subject: [PATCH] fixed issue with non-distributed training, oops

---
 vall_e/data.py | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/vall_e/data.py b/vall_e/data.py
index 4828dca..8e476c1 100755
--- a/vall_e/data.py
+++ b/vall_e/data.py
@@ -167,7 +167,7 @@ class Dataset(_Dataset):
 			else:
 				self.durations[spkr_id] += duration
 
-		if training:
+		if training and not cfg.distributed:
 			self.sampler = Sampler(self.paths, [cfg.get_spkr])
 		else:
 			self.sampler = None
@@ -248,8 +248,7 @@ class Dataset(_Dataset):
 		return prom
 
 	def __getitem__(self, index):
-		if self.training:
-			assert self.sampler is not None
+		if self.training and self.sampler is not None:
 			path = self.sampler.sample()
 		else:
 			path = self.paths[index]
@@ -320,7 +319,7 @@ def _create_dataloader(dataset, training):
 		persistent_workers=True,
 		pin_memory=False, # True,
 		worker_init_fn=_seed_worker,
-		sampler=DistributedSampler(dataset) if cfg.distributed else dataset.sampler
+		sampler=DistributedSampler(dataset) if cfg.distributed else None # dataset.sampler
 	)
 
 def _load_dataset_paths():
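
Reading the three hunks together: before this patch, the custom `Sampler` was built for every training run and passed straight to the `DataLoader` in the non-distributed case. A `DataLoader` sampler must be an iterable of dataset indices, whereas this `Sampler` exposes a `.sample()` method that returns paths, which is presumably why only non-distributed training broke. After the patch, the internal sampler is only built when it will actually be used, `__getitem__` consults it directly, and the `DataLoader` gets `DistributedSampler(dataset)` or `None`. Below is a minimal, self-contained sketch of that wiring; the bucket-by-speaker behavior of `Sampler`, the key lambda, the batch size, and the demo paths are illustrative assumptions, not the real vall_e definitions.

```python
# Minimal sketch of the sampler wiring after this patch. The bucketing
# behavior of `Sampler`, the speaker-key lambda, batch size, and demo
# paths are illustrative assumptions, not the real vall_e code.
import random

from torch.utils.data import DataLoader, Dataset as _Dataset
from torch.utils.data.distributed import DistributedSampler


class Sampler:
	# Stand-in for vall_e's sampler: bucket paths by key (e.g. speaker),
	# pick a bucket uniformly, then a path within it, so sparsely
	# represented speakers are still drawn regularly.
	def __init__(self, paths, keys):
		self.buckets = {}
		for path in paths:
			key = tuple(fn(path) for fn in keys)
			self.buckets.setdefault(key, []).append(path)

	def sample(self):
		key = random.choice(list(self.buckets))
		return random.choice(self.buckets[key])


class Dataset(_Dataset):
	def __init__(self, paths, training, distributed):
		self.paths = paths
		self.training = training
		# Core of the fix: only build the internal sampler when it will
		# actually drive sampling. Under distributed training the
		# DistributedSampler supplies indices, so this stays None.
		if training and not distributed:
			self.sampler = Sampler(paths, [lambda p: p.split("/")[0]])
		else:
			self.sampler = None

	def __len__(self):
		return len(self.paths)

	def __getitem__(self, index):
		# With an internal sampler, `index` is ignored and the sampler
		# chooses the path; otherwise the DataLoader's index is honored.
		if self.training and self.sampler is not None:
			return self.sampler.sample()
		return self.paths[index]


def _create_dataloader(dataset, distributed):
	# Mirrors the patched call site: a DistributedSampler when distributed
	# (requires torch.distributed.init_process_group beforehand), plain
	# None otherwise -- the dataset then samples internally.
	return DataLoader(
		dataset,
		batch_size=4,
		sampler=DistributedSampler(dataset) if distributed else None,
	)


if __name__ == "__main__":
	paths = [f"spkr{i % 3}/utt{i:02d}.pt" for i in range(12)]
	dataset = Dataset(paths, training=True, distributed=False)
	for batch in _create_dataloader(dataset, distributed=False):
		print(batch)  # speaker-balanced random picks; indices are ignored
```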