vall-e/vall_e/emb/g2p.py

83 lines
2.0 KiB
Python
Raw Normal View History

2023-08-02 21:53:35 +00:00
import argparse
import random
import string
import torch
from functools import cache
from pathlib import Path
from phonemizer import phonemize
from phonemizer.backend import BACKENDS
from tqdm import tqdm
@cache
def _get_graphs(path):
with open(path, "r") as f:
graphs = f.read()
return graphs
# Process-wide memo of constructed phonemizer backends, keyed by "<language>_<backend>".
cached_backends = {}


def _get_backend( language="en-us", backend="espeak" ):
    """Return a phonemizer backend for *language*, constructing it on first use.

    Constructed instances are stored in ``cached_backends`` so each
    language/backend pair is only built once per process.

    Args:
        language: Language code passed as the first argument to the backend.
        backend: Key used to look up the backend class in ``BACKENDS``.

    Returns:
        The cached or newly constructed backend instance.
    """
    cache_key = f'{language}_{backend}'
    try:
        return cached_backends[cache_key]
    except KeyError:
        pass

    # espeak also keeps stress marks; espeak-mbrola gets the language only;
    # every other backend keeps punctuation.
    if backend == 'espeak':
        options = { "preserve_punctuation": True, "with_stress": True }
    elif backend == 'espeak-mbrola':
        options = {}
    else:
        options = { "preserve_punctuation": True }

    instance = BACKENDS[backend]( language, **options )
    cached_backends[cache_key] = instance
    return instance
def encode(text: str, language="en-us", backend="auto") -> list[str]:
    """Phonemize *text* and return the result as a list of single characters.

    Args:
        text: The raw text to phonemize, treated as one utterance.
        language: Phonemizer language code. The bare code ``"en"`` is
            normalized to ``"en-us"``.
        backend: Name of the phonemizer backend. ``None``/empty or ``"auto"``
            selects ``"espeak"``.

    Returns:
        The phonemized string split into individual characters. Stress marks
        (U+02C8, U+02CC) and the length mark (U+02D0), when present, are
        emitted as their own tokens rather than merged with a neighbouring
        phone.
    """
    if language == "en":
        language = "en-us"

    if not backend or backend == "auto":
        backend = "espeak"  # if language[:2] != "en" else "festival"

    # The phonemizer API works on a list of utterances; wrap the single input.
    utterances = [text]

    # Keep the backend *name* and the backend *object* in separate variables.
    phonemizer_backend = _get_backend(language=language, backend=backend)
    if phonemizer_backend is not None:
        phonemized = phonemizer_backend.phonemize(utterances, strip=True)
    else:
        # Defensive path: _get_backend in this file always returns an instance.
        phonemized = phonemize(
            utterances,
            language=language,
            strip=True,
            preserve_punctuation=True,
            with_stress=True,
        )

    # One output string per input utterance; split the only one into characters.
    return list(phonemized[0])
2023-08-02 21:53:35 +00:00
@torch.no_grad()
def main():
    """Phonemize every matching text file under a folder.

    Each result is written to ``<name>.phn.txt`` next to its input file.

    Command-line arguments:
        folder: Root directory, searched recursively.
        --suffix: Filename suffix of the input files (default ``.txt``).

    Inputs whose ``.phn.txt`` counterpart already exists are skipped, so an
    interrupted run can simply be restarted.
    """
    # NOTE(review): nothing in this function runs tensor ops; the no_grad
    # decorator appears to be a no-op here and is kept only for parity.
    parser = argparse.ArgumentParser()
    parser.add_argument("folder", type=Path)
    parser.add_argument("--suffix", type=str, default=".txt")
    args = parser.parse_args()

    paths = list(args.folder.rglob(f"*{args.suffix}"))
    for path in tqdm(paths):
        # Everything after the first "." of the stem is dropped, e.g.
        # "utt.normalized.txt" -> "utt.phn.txt". With the default suffix,
        # existing "*.phn.txt" outputs also match the glob; they map onto
        # themselves and are therefore skipped below.
        phone_path = path.with_name(path.stem.split(".")[0] + ".phn.txt")
        if phone_path.exists():
            continue

        # read_text/write_text open *and close* the file, so no handle is left
        # for the garbage collector and the output is flushed immediately.
        phones = encode(path.read_text(encoding="utf-8"))
        phone_path.write_text(" ".join(phones), encoding="utf-8")
2023-08-02 21:53:35 +00:00
if __name__ == "__main__":
    # Run the batch phonemizer only when executed as a script, not on import.
    main()