From 536c11c4aca09b5e03dfbc0add6d9595f3aea198 Mon Sep 17 00:00:00 2001 From: mrq Date: Sat, 21 Sep 2024 12:59:51 -0500 Subject: [PATCH] actually validated and fixed sampling similar utterances for the prompt (hopefully nothing else is needed) --- vall_e/data.py | 30 +++++++++++++++++++++--------- 1 file changed, 21 insertions(+), 9 deletions(-) diff --git a/vall_e/data.py b/vall_e/data.py index 6bea50d..960e1d8 100755 --- a/vall_e/data.py +++ b/vall_e/data.py @@ -879,23 +879,35 @@ class Dataset(_Dataset): return path, text, resps - def get_similar_utterance(self, spkr_name, reference, offset=0 ): - metadata = json_read( cfg.metadata_dir / f"{spkr_name}.json", default={} ) - + # icky slop + def get_similar_utterance(self, path, offset=0 ): + reference = path.name + + if cfg.dataset.use_hdf5: + root = Path( *path.parts[:-1] ) + path = Path( *path.parts[2:-1] ) + else: + root = Path( *path.parts[:-1] ) + path = Path(*path.parts[len(cfg.data_dir.parts):-1]) + + metadata = json_read( cfg.metadata_dir / path.with_suffix(".json"), default={} ) + if reference not in metadata: return None reference_metadata = metadata[reference] - + if "similar" not in reference_metadata: return None - + if len(reference_metadata["similar"]) >= offset: offset = 0 - + metadata_keys = list(metadata.keys()) - name = metadata_keys[reference_metadata["similar"][offset]] - return name + index = reference_metadata["similar"][offset] + name = metadata_keys[index] + + return root / name def sample_prompts(self, spkr_name, reference, should_trim=True): if not cfg.dataset.prompt_duration_range or cfg.dataset.prompt_duration_range[-1] == 0: @@ -920,7 +932,7 @@ class Dataset(_Dataset): for _ in range(cfg.dataset.max_prompts): if reference is not None and cfg.dataset.prom_sample_similar: - path = self.get_similar_utterance( spkr_name=spkr_name, reference=reference, offset = len(prom_list) ) + path = self.get_similar_utterance( reference, offset = len(prom_list) ) # yuck if not path: path = random.choice(choices)