diff --git a/codes/data/audio/voice_tokenizer.py b/codes/data/audio/voice_tokenizer.py index 18188592..b1127664 100644 --- a/codes/data/audio/voice_tokenizer.py +++ b/codes/data/audio/voice_tokenizer.py @@ -29,17 +29,21 @@ def remove_extraneous_punctuation(word): class VoiceBpeTokenizer: - def __init__(self, vocab_file): + def __init__(self, vocab_file=DEFAULT_VOCAB_FILE, preprocess=None): + if preprocess is None: + self.preprocess = vocab_file[-8:] != "ipa.json" + else: + self.preprocess = preprocess if vocab_file is not None: self.tokenizer = Tokenizer.from_file(vocab_file) def preprocess_text(self, txt): txt = english_cleaners(txt) - txt = remove_extraneous_punctuation(txt) return txt def encode(self, txt): - txt = self.preprocess_text(txt) + if self.preprocess: + txt = self.preprocess_text(txt) txt = txt.replace(' ', '[SPACE]') return self.tokenizer.encode(txt).ids @@ -50,7 +54,6 @@ class VoiceBpeTokenizer: txt = txt.replace('[SPACE]', ' ') txt = txt.replace('[STOP]', '') txt = txt.replace('[UNK]', '') - return txt