# Patch summary (commentary only; this text sits ahead of the first `diff --git` header and is not part of any hunk).
# vall_e/config.py: adds two `Dataset` fields, `prompt_similar_top_k: int = 1` and `prompt_similar_top_k_offset: int = 0`.
# vall_e/data.py: `get_similar_utterance` changes the default of `offset` from 0 to None; a None offset is replaced by `cfg.dataset.prompt_similar_top_k_offset`.
# When `cfg.dataset.prompt_similar_top_k > 1`, the index is drawn with `random.choice` from `reference_metadata["similar"][offset:offset+cfg.dataset.prompt_similar_top_k]`; otherwise `reference_metadata["similar"][offset]` is used, as before the patch.
# NOTE(review): `random.choice` raises IndexError on an empty sequence, so an `offset` at or past the end of the "similar" list would fail in the top-k branch. The lines between the second and third hunks are not shown here; confirm they keep `offset` in range.
# NOTE(review): the patch text below appears with its line breaks collapsed onto one line; restore them before applying.
diff --git a/vall_e/config.py b/vall_e/config.py index 615bc99..5a1ff98 100755 --- a/vall_e/config.py +++ b/vall_e/config.py @@ -158,6 +158,8 @@ class Dataset: max_resps: int = 1 # number of samples to target for training p_resp_append: float = 1.0 # probability to append another sample to the training target p_resp_pad_silence: float = 0.0 # probability to pad resp with silence to fit within the next window + prompt_similar_top_k: int = 1 + prompt_similar_top_k_offset: int = 0 sample_type: str = "path" # path | speaker sample_order: str = "interleaved" # duration diff --git a/vall_e/data.py b/vall_e/data.py index fdb553b..b53ede5 100755 --- a/vall_e/data.py +++ b/vall_e/data.py @@ -880,7 +880,10 @@ class Dataset(_Dataset): return path, text, resps # icky slop - def get_similar_utterance(self, path, offset=0 ): + def get_similar_utterance(self, path, offset=None ): + if offset is None: + offset = cfg.dataset.prompt_similar_top_k_offset + reference = path.name if cfg.dataset.use_hdf5: @@ -904,7 +907,12 @@ class Dataset(_Dataset): offset = 0 metadata_keys = list(metadata.keys()) - index = reference_metadata["similar"][offset] + + if cfg.dataset.prompt_similar_top_k > 1: + indices = reference_metadata["similar"][offset:offset+cfg.dataset.prompt_similar_top_k] + index = random.choice( indices ) + else: + index = reference_metadata["similar"][offset] name = metadata_keys[index] return root / name