class VoiceBpeTokenizer:
    """BPE tokenizer wrapper that picks its text cleaning from the vocab file.

    Only ``__init__`` and ``preprocess_text`` are visible in the reviewed
    hunk; the class continues beyond it (``encode`` and anything after it).
    """

    def __init__(self, vocab_file, preprocess=None):
        """Read settings from ``vocab_file`` and build the tokenizer.

        Args:
            vocab_file: Path to the tokenizer JSON file. May be ``None``, in
                which case nothing is read and no ``tokenizer`` attribute is
                created.
            preprocess: Explicit value for ``self.preprocess``. When ``None``,
                the file's top-level ``pre_tokenizer`` entry is used, falling
                back to ``False`` when the entry (or the file) is absent.
        """
        # The guard at the end of this method allows vocab_file=None, so the
        # JSON must not be opened unconditionally.
        vocab = {}
        if vocab_file is not None:
            with open(vocab_file, 'r', encoding='utf-8') as f:
                vocab = json.load(f)

        # Language is optional metadata under "model"; None selects the
        # default (English) cleaning path in preprocess_text.
        model_section = vocab.get('model') or {}
        self.language = model_section.get('language')

        if preprocess is None:
            self.preprocess = vocab.get('pre_tokenizer', False)
        else:
            self.preprocess = preprocess

        if vocab_file is not None:
            self.tokenizer = Tokenizer.from_file(vocab_file)

    def _kakasi(self):
        """Return the pykakasi converter, creating it on first use."""
        converter = getattr(self, '_kakasi_converter', None)
        if converter is None:
            # Imported lazily so pykakasi is only required for Japanese vocabularies.
            import pykakasi

            converter = pykakasi.kakasi()
            self._kakasi_converter = converter
        return converter

    def preprocess_text(self, txt):
        """Clean ``txt`` according to ``self.language`` and return the result.

        For ``'ja'`` the text is converted with pykakasi, the ``'kana'`` field
        of each converted segment is joined with single spaces, and the result
        is passed through ``basic_cleaners``. Any other language value is
        passed through ``english_cleaners``.
        """
        if self.language == 'ja':
            segments = self._kakasi().convert(txt)
            kana_text = " ".join(segment['kana'] for segment in segments)
            return basic_cleaners(kana_text)
        return english_cleaners(txt)