fixed issue with non-distributed training, oops

This commit is contained in:
mrq 2023-08-14 21:42:35 -05:00
parent 5fa86182b5
commit 277c759ab1

View File

@ -167,7 +167,7 @@ class Dataset(_Dataset):
else:
self.durations[spkr_id] += duration
if training:
if training and not cfg.distributed:
self.sampler = Sampler(self.paths, [cfg.get_spkr])
else:
self.sampler = None
@ -248,8 +248,7 @@ class Dataset(_Dataset):
return prom
def __getitem__(self, index):
if self.training:
assert self.sampler is not None
if self.training and self.sampler is not None:
path = self.sampler.sample()
else:
path = self.paths[index]
@ -320,7 +319,7 @@ def _create_dataloader(dataset, training):
persistent_workers=True,
pin_memory=False, # True,
worker_init_fn=_seed_worker,
sampler=DistributedSampler(dataset) if cfg.distributed else dataset.sampler
sampler=DistributedSampler(dataset) if cfg.distributed else None # dataset.sampler
)
def _load_dataset_paths():