Compare commits

...

2 Commits

Author SHA1 Message Date
lightmare
5a4f948785 Added list_languages function 2023-02-23 21:39:05 +00:00
lightmare
b0795baa43 Added token_timestamps and print_timestamps params 2023-02-23 10:16:28 +00:00
2 changed files with 27 additions and 3 deletions

View File

@ -80,6 +80,8 @@ cdef extern from "whisper.h" nogil:
cdef whisper_token_data whisper_sample_best(whisper_context*)
cdef whisper_token whisper_sample_timestamp(whisper_context*)
cdef int whisper_lang_id(char*)
cdef int whisper_lang_max_id()
const char* whisper_lang_str(int)
cdef int whisper_n_len(whisper_context*)
cdef int whisper_n_vocab(whisper_context*)
cdef int whisper_n_text_ctx(whisper_context*)

View File

@ -19,6 +19,8 @@ cdef char* LANGUAGE = b'en'
cdef int N_THREADS = os.cpu_count()
cdef _Bool PRINT_REALTIME = False
cdef _Bool PRINT_PROGRESS = False
cdef _Bool TOKEN_TIMESTAMPS = False
cdef _Bool PRINT_TIMESTAMPS = True
cdef _Bool TRANSLATE = False
@ -59,6 +61,22 @@ def download_model(model, models_dir=MODELS_DIR):
f.write(r.read())
def list_languages():
    """List the languages understood by whisper.cpp, paired with their ids.

    Every id from 0 through ``whisper_lang_max_id()`` (inclusive) is looked
    up with ``whisper_lang_str`` and the resulting C string is decoded to a
    Python ``str``.

    Returns:
        A list of ``(id, code)`` tuples,
        e.g. [(0, "en"), (1, "zh"), ...]
    """
    # The +1 makes range() include the id returned by whisper_lang_max_id().
    cdef int n_languages = whisper_lang_max_id() + 1
    return [
        (lang_id, whisper_lang_str(lang_id).decode())
        for lang_id in range(n_languages)
    ]
cdef cnp.ndarray[cnp.float32_t, ndim=1, mode="c"] load_audio(bytes file, int sr = SAMPLE_RATE):
try:
out = (
@ -85,12 +103,14 @@ cdef cnp.ndarray[cnp.float32_t, ndim=1, mode="c"] load_audio(bytes file, int sr
return frames
cdef whisper_full_params set_params(_Bool print_realtime, _Bool print_progress, _Bool translate, char* language, int n_threads) nogil:
cdef whisper_full_params set_params(_Bool print_realtime, _Bool print_progress, _Bool token_timestamps, _Bool print_timestamps, _Bool translate, char* language, int n_threads) nogil:
cdef whisper_full_params params = whisper_full_default_params(
whisper_sampling_strategy.WHISPER_SAMPLING_GREEDY
)
params.print_realtime = print_realtime
params.print_progress = print_progress
params.token_timestamps = token_timestamps
params.print_timestamps = print_timestamps
params.translate = translate
params.language = <const char *> language
params.n_threads = n_threads
@ -100,7 +120,7 @@ cdef class Whisper:
cdef whisper_context * ctx
cdef whisper_full_params params
def __init__(self, model = DEFAULT_MODEL, models_dir = MODELS_DIR, _Bool print_realtime = PRINT_REALTIME, _Bool print_progress = PRINT_PROGRESS, _Bool translate = TRANSLATE, char* language = LANGUAGE, int n_threads = N_THREADS, _Bool print_system_info = False): # not pretty, look for a way to use kwargs?
def __init__(self, model = DEFAULT_MODEL, models_dir = MODELS_DIR, _Bool print_realtime = PRINT_REALTIME, _Bool print_progress = PRINT_PROGRESS, _Bool token_timestamps = TOKEN_TIMESTAMPS, _Bool print_timestamps = PRINT_TIMESTAMPS, _Bool translate = TRANSLATE, char* language = LANGUAGE, int n_threads = N_THREADS, _Bool print_system_info = False): # not pretty, look for a way to use kwargs?
"""Constructor for Whisper class.
Automatically checks for model and downloads it if necessary.
@ -110,6 +130,8 @@ cdef class Whisper:
models_dir: The path where the models should be stored
print_realtime: whisper.cpp's real time transcription output
print_progress: whisper.cpp's progress indicator
token_timestamps: output timestamps for tokens
print_timestamps: whisper.cpp's timestamped output
translate: whisper.cpp's translation option
language: Which language to use. Must be a byte string.
n_threads: Amount of threads to use
@ -120,7 +142,7 @@ cdef class Whisper:
model_path = Path(models_dir).joinpath(model_fullname)
cdef bytes model_b = str(model_path).encode('utf8')
self.ctx = whisper_init(model_b)
self.params = set_params(print_realtime, print_progress, translate, language, n_threads)
self.params = set_params(print_realtime, print_progress, token_timestamps, print_timestamps, translate, language, n_threads)
if print_system_info:
whisper_print_system_info()