fixed issue with non-distributed training, oops
parent 5fa86182b5
commit 277c759ab1
@@ -167,7 +167,7 @@ class Dataset(_Dataset):
             else:
                 self.durations[spkr_id] += duration
 
-        if training:
+        if training and not cfg.distributed:
             self.sampler = Sampler(self.paths, [cfg.get_spkr])
         else:
             self.sampler = None
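Note on this hunk: Sampler here is the project's own class (it exposes a sample() method, per the next hunk), not torch.utils.data.Sampler. A minimal sketch of the interface the diff implies, with the grouping behavior assumed from the [cfg.get_spkr] argument rather than confirmed by the commit:

import random

class Sampler:
    # Assumed shape of the project's Sampler: group paths by key functions
    # (here a single one, cfg.get_spkr) and draw one path per sample() call,
    # so every group is sampled evenly regardless of its size.
    def __init__(self, paths, keys):
        self.pools = {}
        for path in paths:
            key = tuple(fn(path) for fn in keys)
            self.pools.setdefault(key, []).append(path)

    def sample(self):
        key = random.choice(list(self.pools))
        return random.choice(self.pools[key])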
@@ -248,8 +248,7 @@ class Dataset(_Dataset):
         return prom
 
     def __getitem__(self, index):
-        if self.training:
-            assert self.sampler is not None
+        if self.training and self.sampler is not None:
             path = self.sampler.sample()
         else:
             path = self.paths[index]
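The second hunk replaces the assert with a guard: in distributed runs the dataset no longer builds its own sampler (first hunk), so training lookups fall through to plain indexing, which is exactly what DistributedSampler expects. A standalone sketch of that pairing, with explicit num_replicas/rank so it runs without an initialized process group; PathsDataset is a hypothetical stand-in, not the repo's class:

from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler

class PathsDataset(Dataset):
    def __init__(self, paths):
        self.paths = paths
    def __len__(self):
        return len(self.paths)
    def __getitem__(self, index):
        # Distributed case: no internal sampler, so honor the index.
        return self.paths[index]

ds = PathsDataset([f"utt_{i}.wav" for i in range(8)])
# Each rank sees a disjoint shard of indices; shown here for rank 0 of 2.
sampler = DistributedSampler(ds, num_replicas=2, rank=0, shuffle=True)
for batch in DataLoader(ds, batch_size=2, sampler=sampler):
    print(batch)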
@@ -320,7 +319,7 @@ def _create_dataloader(dataset, training):
         persistent_workers=True,
         pin_memory=False, # True,
         worker_init_fn=_seed_worker,
-        sampler=DistributedSampler(dataset) if cfg.distributed else dataset.sampler
+        sampler=DistributedSampler(dataset) if cfg.distributed else None # dataset.sampler
     )
 
 def _load_dataset_paths():
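Taken together, the non-distributed path after this commit: the DataLoader gets sampler=None (so it feeds sequential indices), and the dataset ignores those indices during training in favor of its internal sampler. A condensed, self-contained sketch of that behavior using hypothetical stand-ins, not the repo's code:

import random
from torch.utils.data import DataLoader, Dataset

class TrainSet(Dataset):
    def __init__(self, paths, training=True):
        self.paths, self.training = paths, training
        # Non-distributed training builds an internal sampler (stubbed here
        # as a uniform draw); otherwise it stays None.
        self.sampler = (lambda: random.choice(paths)) if training else None
    def __len__(self):
        return len(self.paths)
    def __getitem__(self, index):
        if self.training and self.sampler is not None:
            return self.sampler()   # the DataLoader's index is ignored
        return self.paths[index]

# sampler=None: the loader iterates indices sequentially, but training draws
# still come from the dataset's own sampler inside __getitem__.
for batch in DataLoader(TrainSet([f"utt_{i}.wav" for i in range(4)]), batch_size=2):
    print(batch)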