From 44c08d828e344f748e5fab6a79395730d57f76eb Mon Sep 17 00:00:00 2001 From: mrq Date: Wed, 16 Aug 2023 19:39:21 -0500 Subject: [PATCH] added sample_type that samples from speakers to truly balance an epoch by speakers rather than the entire dataset and a sampler that tries to balance by speakers --- vall_e/config.py | 2 ++ vall_e/data.py | 27 ++++++++++++++++++++------- 2 files changed, 22 insertions(+), 7 deletions(-) diff --git a/vall_e/config.py b/vall_e/config.py index 7139e73..bfa2546 100755 --- a/vall_e/config.py +++ b/vall_e/config.py @@ -129,6 +129,8 @@ class Dataset: max_prompts: int = 3 prompt_duration: float = 3.0 + sample_type: str = "path" # path | speaker + @dataclass() class Model: name: str = "" diff --git a/vall_e/data.py b/vall_e/data.py index 8e476c1..f3f371b 100755 --- a/vall_e/data.py +++ b/vall_e/data.py @@ -119,6 +119,7 @@ class Dataset(_Dataset): max_duration=cfg.dataset.duration_range[1], training=False, extra_paths_by_spkr_name: dict[str, list] = {}, + sample_type=cfg.dataset.sample_type # path | speaker ): super().__init__() self._head = None @@ -126,6 +127,7 @@ class Dataset(_Dataset): self.max_phones = max_phones self.min_duration = min_duration self.max_duration = max_duration + self.sample_type = sample_type if cfg.dataset.validate: self.paths = [ @@ -167,7 +169,7 @@ class Dataset(_Dataset): else: self.durations[spkr_id] += duration - if training and not cfg.distributed: + if training and not cfg.distributed and self.sample_type == "path": self.sampler = Sampler(self.paths, [cfg.get_spkr]) else: self.sampler = None @@ -200,10 +202,14 @@ class Dataset(_Dataset): choices = set(self.paths_by_spkr_name[spkr_name]) - {ignore} choices = [*choices] + # no other utterances, it'd make more sense to prune speakers with only one utterance in the validatoin step if len(choices) == 0: + choices = [*set(self.paths_by_spkr_name[spkr_name])] + """ raise ValueError( f"Failed to find another different utterance for {spkr_name}." ) + """ # shuffle it up a bit offset = random.randint(-16, 16) @@ -248,13 +254,17 @@ class Dataset(_Dataset): return prom def __getitem__(self, index): - if self.training and self.sampler is not None: - path = self.sampler.sample() + if hasattr(self, "sample_type") and self.sample_type == "speaker": + spkr_name = self.spkrs[index] + spkr_id = self.spkr_symmap[spkr_name] + path = random.choice([*set(self.paths_by_spkr_name[spkr_name])]) else: - path = self.paths[index] - - spkr_name = cfg.get_spkr(path) - spkr_id = self.spkr_symmap[spkr_name] + if self.training and self.sampler is not None: + path = self.sampler.sample() + else: + path = self.paths[index] + spkr_name = cfg.get_spkr(path) + spkr_id = self.spkr_symmap[spkr_name] if cfg.dataset.use_hdf5: key = _get_hdf5_path(path) @@ -287,6 +297,8 @@ class Dataset(_Dataset): self.paths = [*_interleaved_reorder(self.paths, fn)] def __len__(self): + if hasattr(self, "sample_type") and self.sample_type == "speaker": + return min(len(self.spkrs), self._head or len(self.spkrs)) return min(len(self.paths), self._head or len(self.paths)) def pin_memory(self): @@ -444,6 +456,7 @@ def create_datasets(): def create_train_val_dataloader(): train_dataset, val_dataset = create_datasets() + #train_dataset.sample_type = "speaker" subtrain_dataset = copy.deepcopy(train_dataset) subtrain_dataset.head_(cfg.evaluation.size)