some insanity for sanity checks (some phonemes from phonemizing japanese are not in my tokenizer...)

mrq 2024-07-22 00:30:40 -05:00
parent ad024f400f
commit 491ae2a684
3 changed files with 48 additions and 10 deletions

View File

@@ -115,6 +115,9 @@
 "</eos>": 2,
 "<mask>": 3,
 " ": 4,
+"ᵝ": 0,
 "!": 5,
 "\"": 6,
 "(": 7,
@@ -127,6 +130,7 @@
 ";": 14,
 "?": 15,
 "a": 16,
+"ä": 16,
 "b": 17,
 "c": 18,
 "d": 19,
@@ -134,20 +138,27 @@
 "f": 21,
 "h": 22,
 "i": 23,
+"ĩ": 23,
 "j": 24,
 "k": 25,
 "l": 26,
 "m": 27,
 "n": 28,
+"ɴ": 28,
 "o": 29,
+"̞": 29,
 "p": 30,
+"ɸ": 30,
 "q": 31,
 "r": 32,
+"ɽ": 32,
 "s": 33,
 "t": 34,
 "u": 35,
+"ũ": 35,
 "v": 36,
 "w": 37,
+"ʍ": 37,
 "x": 38,
 "z": 39,
 "¡": 40,
@@ -184,6 +195,7 @@
 "ʲ": 71,
 "ˈ": 72,
 "ˌ": 73,
+"ˌ": 73,
 "ː": 74,
 "̃": 75,
 "̩": 76,

View File

@@ -1505,7 +1505,7 @@ if __name__ == "__main__":
 		samples = {
 			"training": [ next(iter(train_dl)), next(iter(train_dl)) ],
 			"evaluation": [ next(iter(subtrain_dl)), next(iter(subtrain_dl)) ],
-			"validation": [ next(iter(val_dl)), next(iter(val_dl)) ],
+			#"validation": [ next(iter(val_dl)), next(iter(val_dl)) ],
 		}

 		Path("./data/sample-test/").mkdir(parents=True, exist_ok=True)
@@ -1529,6 +1529,30 @@ if __name__ == "__main__":
 		for k, v in samples.items():
 			for i in range(len(v)):
 				print(f'{k}[{i}]:', v[i])
+	elif args.action == "validate":
+		train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
+		missing = set()
+
+		for i in range(len( train_dl.dataset )):
+			batch = train_dl.dataset[i]
+
+			text = batch['text']
+			phonemes = batch['metadata']['phonemes']
+
+			decoded = [ cfg.tokenizer.decode(token) for token in text[1:-1] ]
+			for j, token in enumerate(decoded):  # j, not i, to avoid shadowing the dataset index
+				if token != "<unk>":
+					continue
+
+				phone = phonemes[j]
+
+				print( batch['text'], batch['metadata']['phonemes'] )
+				missing |= set([phone])
+
+		print( "Missing tokens:", missing )
+
 	elif args.action == "tasks":
 		index = 0
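
This validate action is the sanity check from the commit title: it re-tokenizes every training sample and collects each phoneme whose token decodes to <unk>. A standalone sketch of the same check, assuming the vocab sits in an HF tokenizers-style JSON file (the real action walks the training dataloader and cfg.tokenizer instead):

# Standalone sketch of the same <unk> check. "tokenizer.json" and the
# ["model"]["vocab"] layout are assumptions based on the HF `tokenizers`
# JSON format, not the project's actual paths.
import json

def find_missing_phonemes(phoneme_lists, vocab_path="tokenizer.json"):
	with open(vocab_path, encoding="utf-8") as f:
		vocab = json.load(f)["model"]["vocab"]
	missing = set()
	for phonemes in phoneme_lists:
		# any phoneme without a vocab entry would have tokenized to <unk>
		missing |= { p for p in phonemes if p not in vocab }
	return missing

# e.g. find_missing_phonemes([ ["k", "o", "ɴ"], ["ɕ", "i"] ]) -> {"ɕ"} if "ɕ" has no entry

Anything this reports is a candidate for another alias entry in the vocab, which is exactly how the additions in the first file came about.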

View File

@@ -27,23 +27,23 @@ def romanize( runes, sep="" ):
 	return sep.join([ res['hira'] for res in result ])

 cached_backends = {}
-def _get_backend( language="en-us", backend="espeak" ):
+def _get_backend( language="en-us", backend="espeak", punctuation=True, stress=True, strip=True ):
 	key = f'{language}_{backend}'
 	if key in cached_backends:
 		return cached_backends[key]

 	if backend == 'espeak':
-		phonemizer = BACKENDS[backend]( language, preserve_punctuation=True, with_stress=True)
+		phonemizer = BACKENDS[backend]( language, preserve_punctuation=punctuation, with_stress=stress)
 	elif backend == 'espeak-mbrola':
 		phonemizer = BACKENDS[backend]( language )
 	else:
-		phonemizer = BACKENDS[backend]( language, preserve_punctuation=True )
+		phonemizer = BACKENDS[backend]( language, preserve_punctuation=punctuation )

 	cached_backends[key] = phonemizer
 	return phonemizer

-def encode(text: str, language="en-us", backend="auto") -> list[str]:
+def encode(text: str, language="en-us", backend="auto", punctuation=True, stress=True, strip=True) -> list[str]:
 	if language == "en":
 		language = "en-us"
@@ -56,13 +56,15 @@ def encode(text: str, language="en-us", backend="auto") -> list[str]:
 		text = [ text ]

-	backend = _get_backend(language=language, backend=backend)
+	backend = _get_backend(language=language, backend=backend, stress=stress, strip=strip, punctuation=punctuation)
 	if backend is not None:
-		tokens = backend.phonemize( text, strip=True )
+		tokens = backend.phonemize( text, strip=strip )
 	else:
-		tokens = phonemize( text, language=language, strip=True, preserve_punctuation=True, with_stress=True )
+		tokens = phonemize( text, language=language, strip=strip, preserve_punctuation=punctuation, with_stress=stress )

-	tokens = list(tokens[0])
+	if not len(tokens):
+		tokens = []
+	else:
+		tokens = list(tokens[0])

 	return tokens
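
With the settings threaded through, callers can now disable stress marks and punctuation per call instead of always inheriting the espeak defaults, and the empty-result guard keeps encode from raising an IndexError on strings that phonemize to nothing. One caveat: _get_backend still caches on f'{language}_{backend}' alone, so whichever punctuation/stress settings are used first for a given pair are the ones later calls get back. A hedged usage sketch; the "ja" language code and an installed espeak are assumptions:

# Usage sketch for the refactored encode() above. The "ja" language code
# and an available espeak install are assumptions, not verified here.
# from vall_e.emb.g2p import encode  # module path is a guess
phonemes = encode("こんにちは", language="ja", backend="espeak",
                  punctuation=False, stress=False, strip=True)
print(phonemes)  # a list of single-character phoneme strings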