From b0795baa43890b54838de1f33e72b0bf4d400a8a Mon Sep 17 00:00:00 2001 From: lightmare Date: Thu, 23 Feb 2023 10:16:28 +0000 Subject: [PATCH] Added `token_timestamps` and `print_timestamps` params --- whispercpp.pyx | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/whispercpp.pyx b/whispercpp.pyx index 2abaae3..cc9d153 100644 --- a/whispercpp.pyx +++ b/whispercpp.pyx @@ -19,6 +19,8 @@ cdef char* LANGUAGE = b'en' cdef int N_THREADS = os.cpu_count() cdef _Bool PRINT_REALTIME = False cdef _Bool PRINT_PROGRESS = False +cdef _Bool TOKEN_TIMESTAMPS = False +cdef _Bool PRINT_TIMESTAMPS = True cdef _Bool TRANSLATE = False @@ -85,12 +87,14 @@ cdef cnp.ndarray[cnp.float32_t, ndim=1, mode="c"] load_audio(bytes file, int sr return frames -cdef whisper_full_params set_params(_Bool print_realtime, _Bool print_progress, _Bool translate, char* language, int n_threads) nogil: +cdef whisper_full_params set_params(_Bool print_realtime, _Bool print_progress, _Bool token_timestamps, _Bool print_timestamps, _Bool translate, char* language, int n_threads) nogil: cdef whisper_full_params params = whisper_full_default_params( whisper_sampling_strategy.WHISPER_SAMPLING_GREEDY ) params.print_realtime = print_realtime params.print_progress = print_progress + params.token_timestamps = token_timestamps + params.print_timestamps = print_timestamps params.translate = translate params.language = language params.n_threads = n_threads @@ -100,7 +104,7 @@ cdef class Whisper: cdef whisper_context * ctx cdef whisper_full_params params - def __init__(self, model = DEFAULT_MODEL, models_dir = MODELS_DIR, _Bool print_realtime = PRINT_REALTIME, _Bool print_progress = PRINT_PROGRESS, _Bool translate = TRANSLATE, char* language = LANGUAGE, int n_threads = N_THREADS, _Bool print_system_info = False): # not pretty, look for a way to use kwargs? + def __init__(self, model = DEFAULT_MODEL, models_dir = MODELS_DIR, _Bool print_realtime = PRINT_REALTIME, _Bool print_progress = PRINT_PROGRESS, _Bool token_timestamps = TOKEN_TIMESTAMPS, _Bool print_timestamps = PRINT_TIMESTAMPS, _Bool translate = TRANSLATE, char* language = LANGUAGE, int n_threads = N_THREADS, _Bool print_system_info = False): # not pretty, look for a way to use kwargs? """Constructor for Whisper class. Automatically checks for model and downloads it if necessary. @@ -110,6 +114,8 @@ cdef class Whisper: models_dir: The path where the models should be stored print_realtime: whisper.cpp's real time transcription output print_progress: whisper.cpp's progress indicator + token_timestamps: output timestamps for tokens + print_timestamps: whisper.cpp's timestamped output translate: whisper.cpp's translation option language: Which language to use. Must be a byte string. n_threads: Amount of threads to use @@ -120,7 +126,7 @@ cdef class Whisper: model_path = Path(models_dir).joinpath(model_fullname) cdef bytes model_b = str(model_path).encode('utf8') self.ctx = whisper_init(model_b) - self.params = set_params(print_realtime, print_progress, translate, language, n_threads) + self.params = set_params(print_realtime, print_progress, token_timestamps, print_timestamps, translate, language, n_threads) if print_system_info: whisper_print_system_info()