From 74df2f5332b618624eb24e8c7388553ffa9f5f0a Mon Sep 17 00:00:00 2001 From: mrq Date: Sat, 1 Jun 2024 09:29:49 -0500 Subject: [PATCH] split sampler dict by global_rank, also handle splitting dataset paths by global_rank if sampler_type == path (because I do not trust DistributedSampler) (need to test) --- vall_e/data.py | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/vall_e/data.py b/vall_e/data.py index 47b7332..3bea75b 100755 --- a/vall_e/data.py +++ b/vall_e/data.py @@ -13,6 +13,7 @@ import itertools from .config import cfg from .emb.qnt import trim, trim_random, repeat_extend_audio, merge_audio, decode_to_file from .utils.sampler import Sampler +from .utils.distributed import global_rank, local_rank, world_size from collections import defaultdict from functools import cache, cached_property @@ -229,6 +230,23 @@ class Dataset(_Dataset): del self.paths_by_spkr_name[key] self.paths = list(itertools.chain.from_iterable(self.paths_by_spkr_name.values())) + + # split dataset accordingly per GPU + if self.sampler_type == "path" and cfg.distributed and self.training: + batches = len(self.paths) // world_size() + start = batches * global_rank() + end = batches * (global_rank() + 1) + + self.paths = self.paths[start:end] + + # recreate paths_by_spkr_name + self.paths_by_spkr_name = {} + for path in self.paths: + name = cfg.get_spkr( path ) + if name not in self.paths_by_spkr_name[name]: + self.paths_by_spkr_name[name] = [] + self.paths_by_spkr_name[name].append( path ) + self.samplers = { name: Sampler( paths, keep_all=True ) for name, paths in self.paths_by_spkr_name.items() } @@ -707,9 +725,11 @@ def _create_dataloader(dataset, training): sampler = None shuffle = True + """ if cfg.distributed and training: sampler = DistributedSampler(dataset) shuffle = False + """ return DataLoader( dataset=dataset, @@ -728,7 +748,7 @@ def create_datasets(): train_dataset = Dataset( training=True ) val_dataset = Dataset( phone_symmap=train_dataset.phone_symmap, training=False ) - train_state_path = cfg.relpath / "train_dataset.pt" + train_state_path = cfg.relpath / f"train_dataset.{global_rank()}.pt" if train_state_path.exists(): train_dataset.load_state_dict( train_state_path )