This commit is contained in:
Luke Southam 2022-12-11 06:02:45 +00:00
parent 02429c527e
commit c7a8d72e70
2 changed files with 17 additions and 16 deletions

View File

@ -1,9 +1,8 @@
#!python
# cython: language_level=3
from libc.stdint cimport int64_t
cdef:
cdef nogil:
int WHISPER_SAMPLE_RATE = 16000
int WHISPER_N_FFT = 400
int WHISPER_N_MEL = 80
@ -13,6 +12,9 @@ cdef:
char* TEST_FILE = b'test.wav'
char* DEFAULT_MODEL = b'ggml-tiny.bin'
char* LANGUAGE = b'fr'
ctypedef struct audio_data:
float* frames;
int n_frames;
cdef extern from "whisper.h" nogil:
enum whisper_sampling_strategy:
@ -109,8 +111,3 @@ cdef extern from "whisper.h" nogil:
const char* whisper_print_system_info()
const char* whisper_full_get_segment_text(whisper_context*, int)
ctypedef struct audio_data:
float* frames;
int n_frames;

View File

@ -8,12 +8,14 @@ import numpy as np
import requests
import os
cimport numpy as cnp
cdef int SAMPLE_RATE = 16000
cdef char* TEST_FILE = b'test.wav'
cdef char* TEST_FILE = 'test.wav'
cdef char* DEFAULT_MODEL = 'tiny'
cdef char* LANGUAGE = b'fr'
cdef int N_THREADS = os.cpu_count()
MODELS = {
'model_ggml_tiny.bin': 'https://ggml.ggerganov.com/ggml-model-whisper-tiny.bin',
@ -67,7 +69,7 @@ cdef audio_data load_audio(bytes file, int sr = SAMPLE_RATE):
return data
cdef whisper_full_params default_params():
cdef whisper_full_params default_params() nogil:
cdef whisper_full_params params = whisper_full_default_params(
whisper_sampling_strategy.WHISPER_SAMPLING_GREEDY
)
@ -75,6 +77,7 @@ cdef whisper_full_params default_params():
params.print_progress = True
params.translate = False
params.language = <const char *> LANGUAGE
n_threads = N_THREADS
return params
@ -83,25 +86,26 @@ cdef class Whisper:
cdef whisper_full_params params
def __init__(self, model=DEFAULT_MODEL, pb=None):
model_fullname = f'model_ggml_{model.decode()}.bin'.encode('utf8')
model_fullname = f'model_ggml_{model}.bin'.encode('utf8')
download_model(model_fullname)
cdef bytes model_b = model_fullname
self.ctx = whisper_init(model_b)
self.params = default_params()
whisper_print_system_info()
def __dealloc__(self):
whisper_free(self.ctx)
def transcribe(self):
cdef audio_data data = load_audio(TEST_FILE)
def transcribe(self, filename=TEST_FILE):
cdef audio_data data = load_audio(<bytes>filename)
return whisper_full(self.ctx, self.params, data.frames, data.n_frames)
cpdef str extract_text(self, int res):
cpdef list extract_text(self, int res):
if res != 0:
raise RuntimeError
cdef int n_segments = whisper_full_n_segments(self.ctx)
return b'\n'.join([
whisper_full_get_segment_text(self.ctx, i) for i in range(n_segments)
]).decode()
return [
whisper_full_get_segment_text(self.ctx, i).decode() for i in range(n_segments)
]