add top_k sampling / offset for prompt similar utterance sampling

This commit is contained in:
mrq 2024-09-26 16:26:40 -05:00
parent 9da630f73a
commit f24547ad4e
2 changed files with 12 additions and 2 deletions

View File

@ -158,6 +158,8 @@ class Dataset:
max_resps: int = 1 # number of samples to target for training max_resps: int = 1 # number of samples to target for training
p_resp_append: float = 1.0 # probability to append another sample to the training target p_resp_append: float = 1.0 # probability to append another sample to the training target
p_resp_pad_silence: float = 0.0 # probability to pad resp with silence to fit within the next window p_resp_pad_silence: float = 0.0 # probability to pad resp with silence to fit within the next window
prompt_similar_top_k: int = 1
prompt_similar_top_k_offset: int = 0
sample_type: str = "path" # path | speaker sample_type: str = "path" # path | speaker
sample_order: str = "interleaved" # duration sample_order: str = "interleaved" # duration

View File

@ -880,7 +880,10 @@ class Dataset(_Dataset):
return path, text, resps return path, text, resps
# icky slop # icky slop
def get_similar_utterance(self, path, offset=0 ): def get_similar_utterance(self, path, offset=None ):
if offset is None:
offset = cfg.dataset.prompt_similar_top_k_offset
reference = path.name reference = path.name
if cfg.dataset.use_hdf5: if cfg.dataset.use_hdf5:
@ -904,6 +907,11 @@ class Dataset(_Dataset):
offset = 0 offset = 0
metadata_keys = list(metadata.keys()) metadata_keys = list(metadata.keys())
if cfg.dataset.prompt_similar_top_k > 1:
indices = reference_metadata["similar"][offset:offset+cfg.dataset.prompt_similar_top_k]
index = random.choice( indices )
else:
index = reference_metadata["similar"][offset] index = reference_metadata["similar"][offset]
name = metadata_keys[index] name = metadata_keys[index]