split sampler dict by global_rank, also handle splitting dataset paths by global_rank if sampler_type == path (because I do not trust DistributedSampler) (need to test)

mrq 2024-06-01 09:29:49 -05:00
parent 31785f4eeb
commit 74df2f5332

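In effect, each rank takes its own contiguous len(paths) // world_size() slice of the path list instead of letting DistributedSampler stride over a shared dataset. A minimal standalone sketch of that split, with rank and world_size standing in for the helpers imported from .utils.distributed below:

    def shard_paths( paths, rank, world_size ):
        # every rank gets the same number of paths; the remainder
        # (len(paths) % world_size) at the tail is silently dropped
        per_rank = len(paths) // world_size
        return paths[per_rank * rank : per_rank * (rank + 1)]

    # e.g. 10 paths across 3 ranks: paths[0:3], paths[3:6], paths[6:9]; paths[9] is never seen

The equal-sized shards matter: if ranks iterated different numbers of batches, the job would hang at the next distributed synchronization point.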

@@ -13,6 +13,7 @@ import itertools
 from .config import cfg
 from .emb.qnt import trim, trim_random, repeat_extend_audio, merge_audio, decode_to_file
 from .utils.sampler import Sampler
+from .utils.distributed import global_rank, local_rank, world_size
 from collections import defaultdict
 from functools import cache, cached_property
@@ -229,6 +230,23 @@ class Dataset(_Dataset):
             del self.paths_by_spkr_name[key]

         self.paths = list(itertools.chain.from_iterable(self.paths_by_spkr_name.values()))

+        # split dataset accordingly per GPU
+        if self.sampler_type == "path" and cfg.distributed and self.training:
+            batches = len(self.paths) // world_size()
+            start = batches * global_rank()
+            end = batches * (global_rank() + 1)
+
+            self.paths = self.paths[start:end]
+
+            # recreate paths_by_spkr_name for this rank's shard
+            self.paths_by_spkr_name = {}
+            for path in self.paths:
+                name = cfg.get_spkr( path )
+                if name not in self.paths_by_spkr_name:
+                    self.paths_by_spkr_name[name] = []
+                self.paths_by_spkr_name[name].append( path )
+
         self.samplers = { name: Sampler( paths, keep_all=True ) for name, paths in self.paths_by_spkr_name.items() }
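The same rebuild could lean on the defaultdict already imported above; a sketch of the equivalent loop:

    from collections import defaultdict

    paths_by_spkr_name = defaultdict(list)
    for path in paths:
        paths_by_spkr_name[cfg.get_spkr( path )].append( path )

Either way, each rank ends up with its own speaker-to-paths dict, so the per-speaker Sampler objects built on the next line are rank-local as well.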
@@ -707,9 +725,11 @@ def _create_dataloader(dataset, training):
     sampler = None
     shuffle = True

+    """
     if cfg.distributed and training:
         sampler = DistributedSampler(dataset)
         shuffle = False
+    """

     return DataLoader(
         dataset=dataset,
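For reference, the stock PyTorch pattern being quoted out above would also need a set_epoch call every epoch to reshuffle correctly, which is easy to forget and one plausible reason to distrust it; roughly:

    from torch.utils.data import DataLoader
    from torch.utils.data.distributed import DistributedSampler

    sampler = DistributedSampler(dataset, shuffle=True)
    loader = DataLoader(dataset, sampler=sampler, shuffle=False)  # sampler and shuffle are mutually exclusive
    for epoch in range(num_epochs):
        sampler.set_epoch(epoch)  # otherwise every epoch replays the same permutation
        for batch in loader:
            ...

With the dataset pre-split per rank, plain shuffle=True on the DataLoader is enough.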
@@ -728,7 +748,7 @@ def create_datasets():
     train_dataset = Dataset( training=True )
     val_dataset = Dataset( phone_symmap=train_dataset.phone_symmap, training=False )

-    train_state_path = cfg.relpath / "train_dataset.pt"
+    train_state_path = cfg.relpath / f"train_dataset.{global_rank()}.pt"
     if train_state_path.exists():
         train_dataset.load_state_dict( train_state_path )
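Because each rank now carries a different shard and its own sampler dict, the dataset state file is keyed by rank. The save side has to use the same per-rank suffix, or every rank would overwrite the same file; a sketch, assuming a save-side counterpart exists in the training loop:

    # hypothetical save-side counterpart; the path must match the per-rank load path above
    train_state_path = cfg.relpath / f"train_dataset.{global_rank()}.pt"
    train_dataset.save_state_dict( train_state_path )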