Added token_timestamps and print_timestamps params

This commit is contained in:
lightmare 2023-02-23 10:16:28 +00:00
parent 12340d769f
commit b0795baa43

View File

@ -19,6 +19,8 @@ cdef char* LANGUAGE = b'en'
# Module-level defaults; each one is used as the default value of the
# same-named keyword argument of Whisper.__init__.
# NOTE(review): os.cpu_count() may return None, which cannot be coerced to a
# C int -- confirm whether a fallback (e.g. `or 1`) is needed here.
cdef int N_THREADS = os.cpu_count()
cdef _Bool PRINT_REALTIME = False
cdef _Bool PRINT_PROGRESS = False
# Passed through set_params into whisper_full_params.token_timestamps.
cdef _Bool TOKEN_TIMESTAMPS = False
# Passed through set_params into whisper_full_params.print_timestamps.
cdef _Bool PRINT_TIMESTAMPS = True
cdef _Bool TRANSLATE = False
@ -85,12 +87,14 @@ cdef cnp.ndarray[cnp.float32_t, ndim=1, mode="c"] load_audio(bytes file, int sr
return frames
cdef whisper_full_params set_params(_Bool print_realtime, _Bool print_progress, _Bool translate, char* language, int n_threads) nogil:
cdef whisper_full_params set_params(_Bool print_realtime, _Bool print_progress, _Bool token_timestamps, _Bool print_timestamps, _Bool translate, char* language, int n_threads) nogil:
cdef whisper_full_params params = whisper_full_default_params(
whisper_sampling_strategy.WHISPER_SAMPLING_GREEDY
)
params.print_realtime = print_realtime
params.print_progress = print_progress
params.token_timestamps = token_timestamps
params.print_timestamps = print_timestamps
params.translate = translate
params.language = <const char *> language
params.n_threads = n_threads
@ -100,7 +104,7 @@ cdef class Whisper:
cdef whisper_context * ctx
cdef whisper_full_params params
def __init__(self, model = DEFAULT_MODEL, models_dir = MODELS_DIR, _Bool print_realtime = PRINT_REALTIME, _Bool print_progress = PRINT_PROGRESS, _Bool translate = TRANSLATE, char* language = LANGUAGE, int n_threads = N_THREADS, _Bool print_system_info = False): # not pretty, look for a way to use kwargs?
def __init__(self, model = DEFAULT_MODEL, models_dir = MODELS_DIR, _Bool print_realtime = PRINT_REALTIME, _Bool print_progress = PRINT_PROGRESS, _Bool token_timestamps = TOKEN_TIMESTAMPS, _Bool print_timestamps = PRINT_TIMESTAMPS, _Bool translate = TRANSLATE, char* language = LANGUAGE, int n_threads = N_THREADS, _Bool print_system_info = False): # not pretty, look for a way to use kwargs?
"""Constructor for Whisper class.
Automatically checks for model and downloads it if necessary.
@ -110,6 +114,8 @@ cdef class Whisper:
models_dir: The path where the models should be stored
print_realtime: whisper.cpp's real time transcription output
print_progress: whisper.cpp's progress indicator
token_timestamps: output timestamps for tokens
print_timestamps: whisper.cpp's timestamped output
translate: whisper.cpp's translation option
language: Which language to use. Must be a byte string.
n_threads: Number of threads to use
@ -120,7 +126,7 @@ cdef class Whisper:
model_path = Path(models_dir).joinpath(model_fullname)
cdef bytes model_b = str(model_path).encode('utf8')
self.ctx = whisper_init(model_b)
self.params = set_params(print_realtime, print_progress, translate, language, n_threads)
self.params = set_params(print_realtime, print_progress, token_timestamps, print_timestamps, translate, language, n_threads)
if print_system_info:
whisper_print_system_info()