split sampler dict by global_rank, also handle splitting dataset paths by global_rank if sampler_type == path (because I do not trust DistributedSampler) (need to test)
This commit is contained in:
parent
31785f4eeb
commit
74df2f5332
|
@ -13,6 +13,7 @@ import itertools
|
|||
from .config import cfg
|
||||
from .emb.qnt import trim, trim_random, repeat_extend_audio, merge_audio, decode_to_file
|
||||
from .utils.sampler import Sampler
|
||||
from .utils.distributed import global_rank, local_rank, world_size
|
||||
|
||||
from collections import defaultdict
|
||||
from functools import cache, cached_property
|
||||
|
@ -229,6 +230,23 @@ class Dataset(_Dataset):
|
|||
del self.paths_by_spkr_name[key]
|
||||
|
||||
self.paths = list(itertools.chain.from_iterable(self.paths_by_spkr_name.values()))
|
||||
|
||||
# split dataset accordingly per GPU
|
||||
if self.sampler_type == "path" and cfg.distributed and self.training:
|
||||
batches = len(self.paths) // world_size()
|
||||
start = batches * global_rank()
|
||||
end = batches * (global_rank() + 1)
|
||||
|
||||
self.paths = self.paths[start:end]
|
||||
|
||||
# recreate paths_by_spkr_name
|
||||
self.paths_by_spkr_name = {}
|
||||
for path in self.paths:
|
||||
name = cfg.get_spkr( path )
|
||||
if name not in self.paths_by_spkr_name[name]:
|
||||
self.paths_by_spkr_name[name] = []
|
||||
self.paths_by_spkr_name[name].append( path )
|
||||
|
||||
|
||||
self.samplers = { name: Sampler( paths, keep_all=True ) for name, paths in self.paths_by_spkr_name.items() }
|
||||
|
||||
|
@ -707,9 +725,11 @@ def _create_dataloader(dataset, training):
|
|||
sampler = None
|
||||
shuffle = True
|
||||
|
||||
"""
|
||||
if cfg.distributed and training:
|
||||
sampler = DistributedSampler(dataset)
|
||||
shuffle = False
|
||||
"""
|
||||
|
||||
return DataLoader(
|
||||
dataset=dataset,
|
||||
|
@ -728,7 +748,7 @@ def create_datasets():
|
|||
train_dataset = Dataset( training=True )
|
||||
val_dataset = Dataset( phone_symmap=train_dataset.phone_symmap, training=False )
|
||||
|
||||
train_state_path = cfg.relpath / "train_dataset.pt"
|
||||
train_state_path = cfg.relpath / f"train_dataset.{global_rank()}.pt"
|
||||
if train_state_path.exists():
|
||||
train_dataset.load_state_dict( train_state_path )
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user