lang fixes / reworked phoneme symmap validation
parent 84647f588a
commit fa9d3f6c06
@@ -411,6 +411,8 @@ def get_lang_symmap():
 	return {
 		"en": 0,
 		"ja": 1,
+		"de": 2,
+		"fr": 3,
 	}
 
 def get_tone_symmap():
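The hunk above adds German and French to the language symmap. As a rough sketch of how such a map is typically consumed (the encode_lang helper below is hypothetical, not part of this commit), a tag is lowercased and looked up, with unknown or missing tags falling back to the English id:

# Illustrative only: consuming a language symmap like the one above.
lang_symmap = { "en": 0, "ja": 1, "de": 2, "fr": 3 }

def encode_lang(tag):
	# unknown or missing tags fall back to the English id
	return lang_symmap.get((tag or "en").lower(), lang_symmap["en"])

assert encode_lang("FR") == 3
assert encode_lang("zh") == 0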
@@ -751,14 +753,14 @@ class Dataset(_Dataset):
 		res = cfg.get_spkr_group(path)
 		return res
 
-	def get_language(self, speaker_group):
-		lang = "en"
+	# this isn't really necessary since our data/metadata contains markers for languages, but this is still in in-case it's needed to force a language setting (for example, whisperX's lang isn't that accurate at times)
+	def get_language(self, speaker_group, lang="en"):
 		for k, v in cfg.dataset.speaker_languages.items():
 			if speaker_group in v:
 				lang = k
 				break
 
-		return lang
+		return lang.lower()
 
 	@cached_property
 	def spkrs(self):
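For context on the reworked get_language, here is a minimal stand-alone sketch of how the speaker_languages mapping is consulted. The group names below are made-up examples; the real mapping lives in cfg.dataset.speaker_languages:

# Assumed shape of the speaker_languages config (illustrative values only).
speaker_languages = {
	"ja": ["some_japanese_group"],
	"de": ["some_german_group"],
}

def get_language(speaker_group, lang="en"):
	# same shape as the reworked method: first matching group wins, lowercased
	for k, v in speaker_languages.items():
		if speaker_group in v:
			lang = k
			break
	return lang.lower()

assert get_language("some_german_group") == "de"
assert get_language("unlisted_group") == "en"   # falls through to the default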
@@ -877,7 +879,7 @@ class Dataset(_Dataset):
 
 	def get_similar_utterance(self, spkr_name, reference, offset=0 ):
 		# lots of boilerplate checks
-		metadata_path = Path(f"{metadata_root}/{speaker_name}.json")
+		metadata_path = cfg.metadata_dir / f"{spkr_name}.json"
 		if not metadata_path.exists():
 			return None
 		metadata = json_read( metadata_path )
@@ -890,7 +892,8 @@ class Dataset(_Dataset):
 			offset = -1
 		metadata_keys = list(metadata.keys())
 		index = reference_metadata["similar"][offset]
-		return metadata_keys[index]
+		name = metadata_keys[index]
+		return name
 
 	def sample_prompts(self, spkr_name, reference, should_trim=True):
 		if not cfg.dataset.prompt_duration_range or cfg.dataset.prompt_duration_range[-1] == 0:
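Taken together, the two hunks above look up a "similar utterance" by indexing the speaker's metadata keys with an entry from the reference's precomputed similarity list. A simplified stand-alone sketch follows; the dict contents and the offset clamp are illustrative assumptions, not the repo's exact logic:

# Toy metadata: "similar" holds indices into the key order of this dict.
metadata = {
	"utt_000": {"similar": [2, 1]},
	"utt_001": {"similar": [0, 2]},
	"utt_002": {"similar": [1, 0]},
}

def get_similar_utterance(reference, offset=0):
	metadata_keys = list(metadata.keys())
	reference_metadata = metadata[reference]
	if offset >= len(reference_metadata["similar"]):
		offset = -1                  # fall back to the last candidate
	index = reference_metadata["similar"][offset]
	return metadata_keys[index]

assert get_similar_utterance("utt_000") == "utt_002"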
@@ -996,8 +999,7 @@ class Dataset(_Dataset):
 		tone = metadata["tone"] if "tone" in metadata else None
 		text_string = metadata["text"] if "text" in metadata else None
 
-		if not lang:
-			lang = self.get_language(spkr_group)
+		lang = self.get_language(spkr_group) if not lang else lang.lower()
 
 		if not tone:
 			tone = "neutral"
@@ -1642,9 +1644,11 @@ if __name__ == "__main__":
 
 		missing = set()
 
-		for i in range(len( train_dl.dataset )):
-			batch = train_dl.dataset[i]
+		dataset = train_dl.dataset
 
+		for index in tqdm(range(len( dataset )), desc="Processing dataset..."):
+			"""
+			batch = train_dl.dataset[i]
 			text = batch['text']
 			phonemes = batch['metadata']['phonemes']
 
@@ -1657,6 +1661,47 @@ if __name__ == "__main__":
 
 			_logger.info( f"{batch['text']}: {batch['metadata']['phonemes']}" )
 
 			missing |= set([phone])
+			"""
+
+			if dataset.sampler_type == "group":
+				spkr_group = dataset.spkr_groups[index]
+				#spkr_group_id = dataset.spkr_group_symmap[spkr_group]
+				spkr_name = dataset.spkr_samplers[spkr_group].sample()
+				spkr_id = dataset.spkr_symmap[spkr_name]
+				path = dataset.samplers[spkr_name].sample()
+			elif dataset.sampler_type == "speaker":
+				spkr_name = dataset.spkrs[index]
+				spkr_id = dataset.spkr_symmap[spkr_name]
+				path = dataset.samplers[spkr_name].sample()
+				spkr_group = dataset.get_speaker_group(path)
+				#spkr_group_id = dataset.spkr_group_symmap[spkr_group]
+			else:
+				path = dataset.paths[index]
+				spkr_name = dataset.get_speaker(path)
+				spkr_id = dataset.spkr_symmap[spkr_name]
+				spkr_group = dataset.get_speaker_group(path)
+				#spkr_group_id = dataset.spkr_group_symmap[spkr_group]
+
+			if cfg.dataset.use_hdf5:
+				key = _get_hdf5_path(path)
+				if key not in cfg.hdf5:
+					continue
+				metadata = { f'{k}': f'{v}' for k, v in cfg.hdf5[key].attrs.items() }
+				text = cfg.hdf5[key]["text"][:]
+			else:
+				_, metadata = _load_quants(path, return_metadata=True)
+				text = tokenize( phonemes )
+			phonemes = metadata["phonemes"]
+
+			for i, token in enumerate(text):
+				if token != "<unk>":
+					continue
+
+				phone = phonemes[i]
+
+				_logger.info( f"{path}: {phonemes}: {phone}" )
+
+				missing |= set([phone])
 
 		_logger.info( f"Missing tokens: {missing}" )
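The reworked validation loop above tokenizes each utterance's phonemes and records every phoneme that maps to the <unk> token, i.e. every symbol missing from the phoneme symmap. A toy, self-contained version of the same idea (phone_symmap and tokenize here are stand-ins for the repo's versions):

# Stand-alone sketch of the symmap validation: collect symbols that tokenize to <unk>.
phone_symmap = { "<unk>": 0, "h": 1, "ə": 2, "l": 3, "oʊ": 4 }

def tokenize(phonemes):
	return [ phone_symmap.get(p, phone_symmap["<unk>"]) for p in phonemes ]

def find_missing(phonemes):
	missing = set()
	for i, token in enumerate(tokenize(phonemes)):
		if token != phone_symmap["<unk>"]:
			continue
		missing |= { phonemes[i] }
	return missing

print(find_missing(["h", "ə", "l", "oʊ", "ʒ"]))  # {'ʒ'}: not in the symmap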
@@ -189,6 +189,7 @@ def process(
 	if top_k == 0:
 		return
 
+	# fill any missing keys with a null embedding to keep the order the same
 	null_embedding = torch.zeros( (1024,), device=tts.device, dtype=tts.dtype )
 	embeddings = torch.stack( [ feature if feature is not None else null_embedding for feature in features.values() ] )
 	sorted_similarities = {}
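The added comment documents why the null embedding exists: torch.stack() cannot skip None entries and requires every element to be a tensor of the same shape, so any missing feature is replaced by a zero vector, keeping each row of embeddings aligned with the corresponding key of features. A minimal sketch, with assumed dict contents:

import torch

# Missing features become zero vectors so stacking preserves key order.
features = {
	"a.wav": torch.randn(1024),
	"b.wav": None,               # e.g. a file that failed to process
	"c.wav": torch.randn(1024),
}

null_embedding = torch.zeros((1024,))
embeddings = torch.stack([ f if f is not None else null_embedding for f in features.values() ])

assert embeddings.shape == (3, 1024)
assert torch.equal(embeddings[1], null_embedding)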