lang fixes / reworked phoneme symmap validation

This commit is contained in:
mrq 2024-09-18 19:36:03 -05:00
parent 84647f588a
commit fa9d3f6c06
2 changed files with 55 additions and 9 deletions

View File

@ -411,6 +411,8 @@ def get_lang_symmap():
return { return {
"en": 0, "en": 0,
"ja": 1, "ja": 1,
"de": 2,
"fr": 3,
} }
def get_tone_symmap(): def get_tone_symmap():
@ -751,14 +753,14 @@ class Dataset(_Dataset):
res = cfg.get_spkr_group(path) res = cfg.get_spkr_group(path)
return res return res
def get_language(self, speaker_group): # this isn't really necessary since our data/metadata contains markers for languages, but this is still in in-case it's needed to force a language setting (for example, whisperX's lang isn't that accurate at times)
lang = "en" def get_language(self, speaker_group, lang="en"):
for k, v in cfg.dataset.speaker_languages.items(): for k, v in cfg.dataset.speaker_languages.items():
if speaker_group in v: if speaker_group in v:
lang = k lang = k
break break
return lang return lang.lower()
@cached_property @cached_property
def spkrs(self): def spkrs(self):
@ -877,7 +879,7 @@ class Dataset(_Dataset):
def get_similar_utterance(self, spkr_name, reference, offset=0 ): def get_similar_utterance(self, spkr_name, reference, offset=0 ):
# lots of boilerplate checks # lots of boilerplate checks
metadata_path = Path(f"{metadata_root}/{speaker_name}.json") metadata_path = cfg.metadata_dir / f"{spkr_name}.json"
if not metadata_path.exists(): if not metadata_path.exists():
return None return None
metadata = json_read( metadata_path ) metadata = json_read( metadata_path )
@ -890,7 +892,8 @@ class Dataset(_Dataset):
offset = -1 offset = -1
metadata_keys = list(metadata.keys()) metadata_keys = list(metadata.keys())
index = reference_metadata["similar"][offset] index = reference_metadata["similar"][offset]
return metadata_keys[index] name = metadata_keys[index]
return name
def sample_prompts(self, spkr_name, reference, should_trim=True): def sample_prompts(self, spkr_name, reference, should_trim=True):
if not cfg.dataset.prompt_duration_range or cfg.dataset.prompt_duration_range[-1] == 0: if not cfg.dataset.prompt_duration_range or cfg.dataset.prompt_duration_range[-1] == 0:
@ -996,8 +999,7 @@ class Dataset(_Dataset):
tone = metadata["tone"] if "tone" in metadata else None tone = metadata["tone"] if "tone" in metadata else None
text_string = metadata["text"] if "text" in metadata else None text_string = metadata["text"] if "text" in metadata else None
if not lang: lang = self.get_language(spkr_group) if not lang else lang.lower()
lang = self.get_language(spkr_group)
if not tone: if not tone:
tone = "neutral" tone = "neutral"
@ -1642,9 +1644,11 @@ if __name__ == "__main__":
missing = set() missing = set()
for i in range(len( train_dl.dataset )): dataset = train_dl.dataset
batch = train_dl.dataset[i]
for index in tqdm(range(len( dataset )), desc="Processing dataset..."):
"""
batch = train_dl.dataset[i]
text = batch['text'] text = batch['text']
phonemes = batch['metadata']['phonemes'] phonemes = batch['metadata']['phonemes']
@ -1657,6 +1661,47 @@ if __name__ == "__main__":
_logger.info( f"{batch['text']}: {batch['metadata']['phonemes']}" ) _logger.info( f"{batch['text']}: {batch['metadata']['phonemes']}" )
missing |= set([phone])
"""
if dataset.sampler_type == "group":
spkr_group = dataset.spkr_groups[index]
#spkr_group_id = dataset.spkr_group_symmap[spkr_group]
spkr_name = dataset.spkr_samplers[spkr_group].sample()
spkr_id = dataset.spkr_symmap[spkr_name]
path = dataset.samplers[spkr_name].sample()
elif dataset.sampler_type == "speaker":
spkr_name = dataset.spkrs[index]
spkr_id = dataset.spkr_symmap[spkr_name]
path = dataset.samplers[spkr_name].sample()
spkr_group = dataset.get_speaker_group(path)
#spkr_group_id = dataset.spkr_group_symmap[spkr_group]
else:
path = dataset.paths[index]
spkr_name = dataset.get_speaker(path)
spkr_id = dataset.spkr_symmap[spkr_name]
spkr_group = dataset.get_speaker_group(path)
#spkr_group_id = dataset.spkr_group_symmap[spkr_group]
if cfg.dataset.use_hdf5:
key = _get_hdf5_path(path)
if key not in cfg.hdf5:
continue
metadata = { f'{k}': f'{v}' for k, v in cfg.hdf5[key].attrs.items() }
text = cfg.hdf5[key]["text"][:]
else:
_, metadata = _load_quants(path, return_metadata=True)
text = tokenize( phonemes )
phonemes = metadata["phonemes"]
for i, token in enumerate(text):
if token != "<unk>":
continue
phone = phonemes[i]
_logger.info( f"{path}: {phonemes}: {phone}" )
missing |= set([phone]) missing |= set([phone])
_logger.info( f"Missing tokens: {missing}" ) _logger.info( f"Missing tokens: {missing}" )

View File

@ -189,6 +189,7 @@ def process(
if top_k == 0: if top_k == 0:
return return
# fill any missing keys with a null embedding to keep the order the same
null_embedding = torch.zeros( (1024,), device=tts.device, dtype=tts.dtype ) null_embedding = torch.zeros( (1024,), device=tts.device, dtype=tts.dtype )
embeddings = torch.stack( [ feature if feature is not None else null_embedding for feature in features.values() ] ) embeddings = torch.stack( [ feature if feature is not None else null_embedding for feature in features.values() ] )
sorted_similarities = {} sorted_similarities = {}