From fa9d3f6c06d4a12fd840d965480af09be0c7b25c Mon Sep 17 00:00:00 2001 From: mrq Date: Wed, 18 Sep 2024 19:36:03 -0500 Subject: [PATCH] lang fixes / reworked phoneme symmap validation --- vall_e/data.py | 63 ++++++++++++++++++++++++++++++++++++------- vall_e/emb/similar.py | 1 + 2 files changed, 55 insertions(+), 9 deletions(-) diff --git a/vall_e/data.py b/vall_e/data.py index 975505c..f0b0904 100755 --- a/vall_e/data.py +++ b/vall_e/data.py @@ -411,6 +411,8 @@ def get_lang_symmap(): return { "en": 0, "ja": 1, + "de": 2, + "fr": 3, } def get_tone_symmap(): @@ -751,14 +753,14 @@ class Dataset(_Dataset): res = cfg.get_spkr_group(path) return res - def get_language(self, speaker_group): - lang = "en" + # this isn't strictly necessary since our data/metadata already contains language markers, but it is kept in case a language setting needs to be forced (for example, whisperX's detected language isn't always accurate) + def get_language(self, speaker_group, lang="en"): for k, v in cfg.dataset.speaker_languages.items(): if speaker_group in v: lang = k break - return lang + return lang.lower() @cached_property def spkrs(self): @@ -877,7 +879,7 @@ class Dataset(_Dataset): def get_similar_utterance(self, spkr_name, reference, offset=0 ): # lots of boilerplate checks - metadata_path = Path(f"{metadata_root}/{speaker_name}.json") + metadata_path = cfg.metadata_dir / f"{spkr_name}.json" if not metadata_path.exists(): return None metadata = json_read( metadata_path ) @@ -890,7 +892,8 @@ class Dataset(_Dataset): offset = -1 metadata_keys = list(metadata.keys()) index = reference_metadata["similar"][offset] - return metadata_keys[index] + name = metadata_keys[index] + return name def sample_prompts(self, spkr_name, reference, should_trim=True): if not cfg.dataset.prompt_duration_range or cfg.dataset.prompt_duration_range[-1] == 0: @@ -996,8 +999,7 @@ class Dataset(_Dataset): tone = metadata["tone"] if "tone" in metadata else None text_string = metadata["text"] if "text" in metadata 
else None - if not lang: - lang = self.get_language(spkr_group) + lang = self.get_language(spkr_group) if not lang else lang.lower() if not tone: tone = "neutral" @@ -1642,9 +1644,11 @@ if __name__ == "__main__": missing = set() - for i in range(len( train_dl.dataset )): - batch = train_dl.dataset[i] + dataset = train_dl.dataset + for index in tqdm(range(len( dataset )), desc="Processing dataset..."): + """ + batch = train_dl.dataset[i] text = batch['text'] phonemes = batch['metadata']['phonemes'] @@ -1657,6 +1661,47 @@ if __name__ == "__main__": _logger.info( f"{batch['text']}: {batch['metadata']['phonemes']}" ) + missing |= set([phone]) + """ + + if dataset.sampler_type == "group": + spkr_group = dataset.spkr_groups[index] + #spkr_group_id = dataset.spkr_group_symmap[spkr_group] + spkr_name = dataset.spkr_samplers[spkr_group].sample() + spkr_id = dataset.spkr_symmap[spkr_name] + path = dataset.samplers[spkr_name].sample() + elif dataset.sampler_type == "speaker": + spkr_name = dataset.spkrs[index] + spkr_id = dataset.spkr_symmap[spkr_name] + path = dataset.samplers[spkr_name].sample() + spkr_group = dataset.get_speaker_group(path) + #spkr_group_id = dataset.spkr_group_symmap[spkr_group] + else: + path = dataset.paths[index] + spkr_name = dataset.get_speaker(path) + spkr_id = dataset.spkr_symmap[spkr_name] + spkr_group = dataset.get_speaker_group(path) + #spkr_group_id = dataset.spkr_group_symmap[spkr_group] + + if cfg.dataset.use_hdf5: + key = _get_hdf5_path(path) + if key not in cfg.hdf5: + continue + metadata = { f'{k}': f'{v}' for k, v in cfg.hdf5[key].attrs.items() } + text = cfg.hdf5[key]["text"][:] + else: + _, metadata = _load_quants(path, return_metadata=True) + text = tokenize( phonemes ) + phonemes = metadata["phonemes"] + + for i, token in enumerate(text): + if token != "": + continue + + phone = phonemes[i] + + _logger.info( f"{path}: {phonemes}: {phone}" ) + missing |= set([phone]) _logger.info( f"Missing tokens: {missing}" ) diff --git 
a/vall_e/emb/similar.py b/vall_e/emb/similar.py index 4453f48..3992dae 100644 --- a/vall_e/emb/similar.py +++ b/vall_e/emb/similar.py @@ -189,6 +189,7 @@ def process( if top_k == 0: return + # fill any missing keys with a null embedding to keep the order the same null_embedding = torch.zeros( (1024,), device=tts.device, dtype=tts.dtype ) embeddings = torch.stack( [ feature if feature is not None else null_embedding for feature in features.values() ] ) sorted_similarities = {}