Compare commits
2 Commits
12340d769f
...
5a4f948785
Author | SHA1 | Date | |
---|---|---|---|
|
5a4f948785 | ||
|
b0795baa43 |
|
@ -80,6 +80,8 @@ cdef extern from "whisper.h" nogil:
|
|||
cdef whisper_token_data whisper_sample_best(whisper_context*)
|
||||
cdef whisper_token whisper_sample_timestamp(whisper_context*)
|
||||
cdef int whisper_lang_id(char*)
|
||||
cdef int whisper_lang_max_id()
|
||||
const char* whisper_lang_str(int)
|
||||
cdef int whisper_n_len(whisper_context*)
|
||||
cdef int whisper_n_vocab(whisper_context*)
|
||||
cdef int whisper_n_text_ctx(whisper_context*)
|
||||
|
|
|
@ -19,6 +19,8 @@ cdef char* LANGUAGE = b'en'
|
|||
# os.cpu_count() returns None when the number of CPUs cannot be determined;
# fall back to a single thread so the conversion to a C int cannot fail when
# this module is imported.
cdef int N_THREADS = os.cpu_count() or 1
# Default values for the Whisper constructor arguments of the same
# (lower-case) names.
cdef _Bool PRINT_REALTIME = False
cdef _Bool PRINT_PROGRESS = False
cdef _Bool TOKEN_TIMESTAMPS = False
cdef _Bool PRINT_TIMESTAMPS = True
cdef _Bool TRANSLATE = False
|
||||
|
||||
|
||||
|
@ -59,6 +61,22 @@ def download_model(model, models_dir=MODELS_DIR):
|
|||
f.write(r.read())
|
||||
|
||||
|
||||
def list_languages():
    """Enumerate the languages understood by whisper.cpp.

    Returns:
        A list of (id, language code) tuples, one for every id from 0
        through whisper_lang_max_id(), e.g. [(0, "en"), (1, "zh"), ...]
    """
    # whisper_lang_max_id() is the highest valid id, so the count is one more.
    cdef int n_langs = whisper_lang_max_id() + 1
    return [
        (lang_id, whisper_lang_str(lang_id).decode())
        for lang_id in range(n_langs)
    ]
|
||||
|
||||
|
||||
cdef cnp.ndarray[cnp.float32_t, ndim=1, mode="c"] load_audio(bytes file, int sr = SAMPLE_RATE):
|
||||
try:
|
||||
out = (
|
||||
|
@ -85,12 +103,14 @@ cdef cnp.ndarray[cnp.float32_t, ndim=1, mode="c"] load_audio(bytes file, int sr
|
|||
|
||||
return frames
|
||||
|
||||
cdef whisper_full_params set_params(_Bool print_realtime, _Bool print_progress, _Bool translate, char* language, int n_threads) nogil:
|
||||
cdef whisper_full_params set_params(_Bool print_realtime, _Bool print_progress, _Bool token_timestamps, _Bool print_timestamps, _Bool translate, char* language, int n_threads) nogil:
|
||||
cdef whisper_full_params params = whisper_full_default_params(
|
||||
whisper_sampling_strategy.WHISPER_SAMPLING_GREEDY
|
||||
)
|
||||
params.print_realtime = print_realtime
|
||||
params.print_progress = print_progress
|
||||
params.token_timestamps = token_timestamps
|
||||
params.print_timestamps = print_timestamps
|
||||
params.translate = translate
|
||||
params.language = <const char *> language
|
||||
params.n_threads = n_threads
|
||||
|
@ -100,7 +120,7 @@ cdef class Whisper:
|
|||
cdef whisper_context * ctx
|
||||
cdef whisper_full_params params
|
||||
|
||||
def __init__(self, model = DEFAULT_MODEL, models_dir = MODELS_DIR, _Bool print_realtime = PRINT_REALTIME, _Bool print_progress = PRINT_PROGRESS, _Bool translate = TRANSLATE, char* language = LANGUAGE, int n_threads = N_THREADS, _Bool print_system_info = False): # not pretty, look for a way to use kwargs?
|
||||
def __init__(self, model = DEFAULT_MODEL, models_dir = MODELS_DIR, _Bool print_realtime = PRINT_REALTIME, _Bool print_progress = PRINT_PROGRESS, _Bool token_timestamps = TOKEN_TIMESTAMPS, _Bool print_timestamps = PRINT_TIMESTAMPS, _Bool translate = TRANSLATE, char* language = LANGUAGE, int n_threads = N_THREADS, _Bool print_system_info = False): # not pretty, look for a way to use kwargs?
|
||||
"""Constructor for Whisper class.
|
||||
|
||||
Automatically checks for model and downloads it if necessary.
|
||||
|
@ -110,6 +130,8 @@ cdef class Whisper:
|
|||
models_dir: The path where the models should be stored
|
||||
print_realtime: whisper.cpp's real time transcription output
|
||||
print_progress: whisper.cpp's progress indicator
|
||||
token_timestamps: output timestamps for tokens
|
||||
print_timestamps: whisper.cpp's timestamped output
|
||||
translate: whisper.cpp's translation option
|
||||
language: Which language to use. Must be a byte string.
|
||||
n_threads: Amount of threads to use
|
||||
|
@ -120,7 +142,7 @@ cdef class Whisper:
|
|||
model_path = Path(models_dir).joinpath(model_fullname)
|
||||
cdef bytes model_b = str(model_path).encode('utf8')
|
||||
self.ctx = whisper_init(model_b)
|
||||
self.params = set_params(print_realtime, print_progress, translate, language, n_threads)
|
||||
self.params = set_params(print_realtime, print_progress, token_timestamps, print_timestamps, translate, language, n_threads)
|
||||
if print_system_info:
|
||||
whisper_print_system_info()
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user