split sampler dict by global_rank; also split dataset paths by global_rank when sampler_type == "path" (because I do not trust DistributedSampler) (need to test)
parent 31785f4eeb
commit 74df2f5332
@@ -13,6 +13,7 @@ import itertools
 from .config import cfg
 from .emb.qnt import trim, trim_random, repeat_extend_audio, merge_audio, decode_to_file
 from .utils.sampler import Sampler
+from .utils.distributed import global_rank, local_rank, world_size

 from collections import defaultdict
 from functools import cache, cached_property
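For reference, the .utils.distributed helpers imported here are not part of this diff; below is a minimal sketch of what global_rank, local_rank, and world_size typically wrap, assuming a torch.distributed backend (the actual module may differ):

    # Hypothetical sketch of .utils.distributed; the real module may differ.
    import os
    import torch.distributed as dist

    def world_size() -> int:
        # Total number of training processes across all nodes.
        return dist.get_world_size() if dist.is_initialized() else 1

    def global_rank() -> int:
        # This process's rank across all nodes, in [0, world_size()).
        return dist.get_rank() if dist.is_initialized() else 0

    def local_rank() -> int:
        # This process's rank on its own node; torchrun exports LOCAL_RANK.
        return int(os.environ.get("LOCAL_RANK", 0))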
@@ -229,6 +230,23 @@ class Dataset(_Dataset):
             del self.paths_by_spkr_name[key]

         self.paths = list(itertools.chain.from_iterable(self.paths_by_spkr_name.values()))

+        # split dataset accordingly per GPU
+        if self.sampler_type == "path" and cfg.distributed and self.training:
+            batches = len(self.paths) // world_size()
+            start = batches * global_rank()
+            end = batches * (global_rank() + 1)
+
+            self.paths = self.paths[start:end]
+
+            # recreate paths_by_spkr_name
+            self.paths_by_spkr_name = {}
+            for path in self.paths:
+                name = cfg.get_spkr( path )
+                if name not in self.paths_by_spkr_name:
+                    self.paths_by_spkr_name[name] = []
+                self.paths_by_spkr_name[name].append( path )
+
         self.samplers = { name: Sampler( paths, keep_all=True ) for name, paths in self.paths_by_spkr_name.items() }
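A quick standalone illustration of the slice arithmetic above, with hypothetical filenames: each rank keeps a contiguous block of len(self.paths) // world_size() paths, so any remainder at the tail is dropped on every rank.

    # Worked example of the per-rank split, world_size = 4 (hypothetical data).
    paths = [f"utt_{i}.qnt.pt" for i in range(10)]
    world_size = 4
    batches = len(paths) // world_size   # 2 paths per rank
    for rank in range(world_size):
        start = batches * rank
        end = batches * (rank + 1)
        print(rank, paths[start:end])
    # rank 0 gets utt_0..utt_1, ..., rank 3 gets utt_6..utt_7;
    # utt_8 and utt_9 fall off the end on every rank.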
@@ -707,9 +725,11 @@ def _create_dataloader(dataset, training):
     sampler = None
     shuffle = True

+    """
     if cfg.distributed and training:
         sampler = DistributedSampler(dataset)
         shuffle = False
+    """

     return DataLoader(
         dataset=dataset,
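The block quoted out above is the stock PyTorch sharding path this commit stops relying on; for reference, a minimal sketch of that usage (standard torch.utils.data API, with dataset standing in for the Dataset instance):

    # Stock PyTorch sharding that the commit disables in favor of
    # pre-splitting self.paths per rank.
    from torch.utils.data import DataLoader
    from torch.utils.data.distributed import DistributedSampler

    sampler = DistributedSampler(dataset)  # shards via dist.get_rank()/get_world_size()
    loader = DataLoader(dataset, sampler=sampler, shuffle=False)
    # sampler.set_epoch(epoch) should be called each epoch to reshuffle.

Note the behavioral difference: DistributedSampler pads (or, with drop_last=True, truncates) so every rank sees the same number of samples, whereas the manual split above simply drops the remainder.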
|
@ -728,7 +748,7 @@ def create_datasets():
|
||||||
train_dataset = Dataset( training=True )
|
train_dataset = Dataset( training=True )
|
||||||
val_dataset = Dataset( phone_symmap=train_dataset.phone_symmap, training=False )
|
val_dataset = Dataset( phone_symmap=train_dataset.phone_symmap, training=False )
|
||||||
|
|
||||||
train_state_path = cfg.relpath / "train_dataset.pt"
|
train_state_path = cfg.relpath / f"train_dataset.{global_rank()}.pt"
|
||||||
if train_state_path.exists():
|
if train_state_path.exists():
|
||||||
train_dataset.load_state_dict( train_state_path )
|
train_dataset.load_state_dict( train_state_path )
|
||||||
|
|
||||||
|
|
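Because each rank now owns a different slice of the paths and its own sampler dict, the dataset state path is rank-specific, so each process restores its own sampler state. A hypothetical 2-GPU run would use files like:

    # Hypothetical illustration of the per-rank state files this change implies.
    from pathlib import Path
    relpath = Path("./data")  # stand-in for cfg.relpath
    for rank in range(2):     # world_size = 2
        print(relpath / f"train_dataset.{rank}.pt")
    # ./data/train_dataset.0.pt
    # ./data/train_dataset.1.pt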