fixed issue with non-distributed training, oops

This commit is contained in:
mrq 2023-08-14 21:42:35 -05:00
parent 5fa86182b5
commit 277c759ab1

View File

@ -167,7 +167,7 @@ class Dataset(_Dataset):
else: else:
self.durations[spkr_id] += duration self.durations[spkr_id] += duration
if training: if training and not cfg.distributed:
self.sampler = Sampler(self.paths, [cfg.get_spkr]) self.sampler = Sampler(self.paths, [cfg.get_spkr])
else: else:
self.sampler = None self.sampler = None
@ -248,8 +248,7 @@ class Dataset(_Dataset):
return prom return prom
def __getitem__(self, index): def __getitem__(self, index):
if self.training: if self.training and self.sampler is not None:
assert self.sampler is not None
path = self.sampler.sample() path = self.sampler.sample()
else: else:
path = self.paths[index] path = self.paths[index]
@ -320,7 +319,7 @@ def _create_dataloader(dataset, training):
persistent_workers=True, persistent_workers=True,
pin_memory=False, # True, pin_memory=False, # True,
worker_init_fn=_seed_worker, worker_init_fn=_seed_worker,
sampler=DistributedSampler(dataset) if cfg.distributed else dataset.sampler sampler=DistributedSampler(dataset) if cfg.distributed else None # dataset.sampler
) )
def _load_dataset_paths(): def _load_dataset_paths():