From 4a71981456081e09220b7bd7b6206f4eab0dcb16 Mon Sep 17 00:00:00 2001
From: mrq
Date: Mon, 18 Nov 2024 12:46:50 -0600
Subject: [PATCH] normalize sampler index by batch size (if not using batched
 sampler), add option to cap out utterances for a speaker, some other things

---
 vall_e/config.py |  1 +
 vall_e/data.py   | 35 ++++++++++++++++++-----------------
 2 files changed, 19 insertions(+), 17 deletions(-)

diff --git a/vall_e/config.py b/vall_e/config.py
index 0f56c9b..85eb89d 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -171,6 +171,7 @@ class Dataset:
 	cache: bool = True # use diskcache to cache the dataset
 
 	min_utterances: int = 2 # minimum number of utterances a speaker can have
+	max_utterances: int = 0 # max number of utterances a speaker can have (0 to disable)
 	duration_range: list[float] = field(default_factory=lambda: [1.0, 12.0]) # the duration range an utterance can be to be included in the dataset
 
 	sample_type: str = "path" # path | speaker
diff --git a/vall_e/data.py b/vall_e/data.py
index 9a7f644..239ee81 100755
--- a/vall_e/data.py
+++ b/vall_e/data.py
@@ -689,6 +689,11 @@ class Dataset(_Dataset):
 		self.sampler_shuffle = cfg.dataset.sample_shuffle if self.dataset_type == "training" else True
 
 		self.dataset_hash_key = cfg.dataset.hash_key(sorted(self.dataset))
+
+		self.duration = 0
+		self.duration_buckets = {}
+		self.current_index = 0
+		self.batch_size = cfg.hyperparameters.batch_size if self.training else cfg.evaluation.batch_size
 
 		# to-do: do not do validation if there's nothing in the validation
 		# this just makes it be happy
@@ -701,31 +706,24 @@ class Dataset(_Dataset):
 		# dict of paths keyed by speaker names
 		self.paths_by_spkr_name = _load_paths(self.dataset, self.dataset_type, dataset_hash_key=self.dataset_hash_key)
 
-		# do it here due to the above
-		self.duration = 0
 		self.duration_map = _get_duration_map( self.dataset_type )
-		self.duration_buckets = {}
 
-		# cull speakers if they do not have enough utterances
-		if cfg.dataset.min_utterances > 0:
+		# cull speakers if they do not have enough utterances (or cull speakers with too many utterances)
+		if cfg.dataset.min_utterances > 0 or cfg.dataset.max_utterances > 0:
 			keys = list(self.paths_by_spkr_name.keys())
 			for key in keys:
 				if len(self.paths_by_spkr_name[key]) < cfg.dataset.min_utterances:
 					del self.paths_by_spkr_name[key]
+				# slice away extraneous utterances
+				if cfg.dataset.max_utterances:
+					self.paths_by_spkr_name[key] = self.paths_by_spkr_name[key][:cfg.dataset.max_utterances]
+
 
 		# flatten paths
 		self.paths = list(itertools.chain.from_iterable(self.paths_by_spkr_name.values()))
 
 		# split dataset accordingly per GPU
 		if cfg.distributed and self.training:
-			"""
-			batches = len(self.paths) // world_size()
-			start = batches * global_rank()
-			end = batches * (global_rank() + 1)
-
-			self.paths = self.paths[start:end]
-			"""
-
 			self.paths = [ path for i, path in enumerate(self.paths) if i % world_size() == 0 ]
 
 			# recreate paths_by_spkr_name
@@ -819,9 +817,10 @@ class Dataset(_Dataset):
 			self.sampler = BatchedOrderedSampler(
 				self.duration_buckets if not self.sampler_state_dict_path.exists() else {}, # pass nothing if we're just going to load from a state anyways
 				max_duration=cfg.dataset.sample_max_duration_batch,
-				max_batch_size=cfg.hyperparameters.batch_size if self.training else cfg.evaluation.batch_size,
+				max_batch_size=self.batch_size,
 				shuffle=self.sampler_shuffle,
 			)
+			self.batch_size = 1
 		else:
 			self.sampler = OrderedSampler( len(self) ) if not self.sampler_shuffle else RandomSampler( len(self) )
 			self.samplers = {}
@@ -1084,6 +1083,8 @@ class Dataset(_Dataset):
 		return prom
 
 	def __getitem__(self, index):
+		self.current_index = index
+
 		if self.empty_text is None:
 			self.empty_text = tokenize(" ")
 
@@ -1385,7 +1386,7 @@ class Dataset(_Dataset):
 		self.training = value
 
 	def index(self):
-		return self.sampler.index() if self.sampler is not None else -1
+		return (self.sampler.index() if self.sampler is not None else -1) // self.batch_size
 
 	def __len__(self):
 		if self.sampler_type == "group":
@@ -1471,8 +1472,8 @@ def create_train_val_dataloader():
 	val_dl = _create_dataloader(val_dataset, training=False)
 
 	_logger.info(str(train_dataset.phone_symmap))
-	_logger.info(str(train_dataset.spkr_symmap))
-	_logger.info(str(train_dataset.spkr_group_symmap))
+	_logger.info(f'#speakers (train): {len(train_dataset.spkr_symmap)}')
+	_logger.info(f'#groups (train): {len(train_dataset.spkr_group_symmap)}')
 
 	_logger.info(f"#samples (train): {len(train_dataset)}.")
 	_logger.info(f"#samples (val): {len(val_dataset)}.")
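
Note (illustration, not part of the patch): the speaker culling/capping behavior added to data.py boils down to the standalone sketch below. The helper name cap_speakers and its argument names are hypothetical; min_utterances/max_utterances mirror the new config fields, where 0 disables each check, and utterance paths are assumed to live in a dict keyed by speaker name.

# Hypothetical standalone helper; not the patch's Dataset.__init__ logic verbatim.
def cap_speakers( paths_by_spkr_name, min_utterances=0, max_utterances=0 ):
	for key in list(paths_by_spkr_name.keys()):
		# cull speakers with too few utterances
		if min_utterances and len(paths_by_spkr_name[key]) < min_utterances:
			del paths_by_spkr_name[key]
			continue
		# slice away extraneous utterances past the cap (0 disables the cap)
		if max_utterances:
			paths_by_spkr_name[key] = paths_by_spkr_name[key][:max_utterances]
	return paths_by_spkr_name

# e.g. drop speakers with fewer than 2 utterances, keep at most 500 per speaker
# paths_by_spkr_name = cap_speakers( paths_by_spkr_name, min_utterances=2, max_utterances=500 )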
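
Note (illustration, not part of the patch): the index() normalization works because a per-sample sampler (OrderedSampler/RandomSampler) advances once per __getitem__ call, so its raw position is divided by the dataloader batch size to report progress in batches, while BatchedOrderedSampler already advances once per batch, which is why the patch resets self.batch_size to 1 in that branch. The class below is hypothetical and only illustrates the arithmetic.

# Hypothetical illustration of the index() normalization; not vall_e's Dataset.
class ProgressIndex:
	def __init__( self, sampler, batch_size, batched_sampler=False ):
		self.sampler = sampler
		# a batched sampler steps once per batch, so no further division is needed
		self.batch_size = 1 if batched_sampler else batch_size

	def index( self ):
		# raw sampler position, or -1 when no sampler is attached
		raw = self.sampler.index() if self.sampler is not None else -1
		return raw // self.batch_size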