actually validated and fixed sampling similar utterances for the prompt (hopefully nothing else is needed)

This commit is contained in:
mrq 2024-09-21 12:59:51 -05:00
parent d31f27119a
commit 536c11c4ac

View File

@ -879,23 +879,35 @@ class Dataset(_Dataset):
return path, text, resps
# NOTE(review): this appears to be the pre-commit signature, which located the
# metadata file by speaker name (`<spkr_name>.json`); the definition that
# follows takes the utterance path instead — confirm against the raw diff.
def get_similar_utterance(self, spkr_name, reference, offset=0 ):
metadata = json_read( cfg.metadata_dir / f"{spkr_name}.json", default={} )
# icky slop
def get_similar_utterance(self, path, offset=0 ):
    """Return the path of an utterance recorded as similar to `path`.

    The metadata file is looked up under `cfg.metadata_dir`, using the
    directory of `path` (with a leading prefix stripped) plus a `.json`
    suffix. The entry keyed by `path.name` must carry a "similar" list whose
    items are integer positions into the metadata file's key order.

    Args:
        path: Path-like location of the reference utterance.
        offset: Which entry of the "similar" list to pick. An offset past
            the end of the list wraps back to the first entry.

    Returns:
        `path`'s parent directory joined with the selected utterance name,
        or None when no metadata entry, no "similar" field, or an empty
        "similar" list exists for the reference (callers then fall back to
        their own selection).
    """
    reference = path.name
    # Directory holding the reference; the selected name is re-rooted here.
    root = Path( *path.parts[:-1] )

    if cfg.dataset.use_hdf5:
        # Drops the first two components and the file name; presumably the
        # leading separator and dataset root of the HDF5 key — TODO confirm.
        path = Path( *path.parts[2:-1] )
    else:
        # Make the directory relative to cfg.data_dir, without the file name.
        path = Path( *path.parts[len(cfg.data_dir.parts):-1] )

    metadata = json_read( cfg.metadata_dir / path.with_suffix(".json"), default={} )

    if reference not in metadata:
        return None

    reference_metadata = metadata[reference]

    if "similar" not in reference_metadata:
        return None

    similar = reference_metadata["similar"]
    count = len(similar)

    # Nothing to choose from: report "no match" rather than raising IndexError.
    if count == 0:
        return None

    # Wrap only when the requested slot is past the end. The previous check
    # (`len(similar) >= offset`) was inverted: it reset every in-range offset
    # to 0 and let genuinely out-of-range offsets through.
    if offset >= count:
        offset = 0

    # "similar" stores positions into the metadata's key order, not names.
    metadata_keys = list(metadata.keys())
    index = similar[offset]
    name = metadata_keys[index]

    return root / name
def sample_prompts(self, spkr_name, reference, should_trim=True):
if not cfg.dataset.prompt_duration_range or cfg.dataset.prompt_duration_range[-1] == 0:
@ -920,7 +932,7 @@ class Dataset(_Dataset):
for _ in range(cfg.dataset.max_prompts):
if reference is not None and cfg.dataset.prom_sample_similar:
path = self.get_similar_utterance( spkr_name=spkr_name, reference=reference, offset = len(prom_list) )
path = self.get_similar_utterance( reference, offset = len(prom_list) )
# yuck
if not path:
path = random.choice(choices)