Merge remote-tracking branch 'upstream/main'

This commit is contained in:
Rasmus Larsen 2023-01-30 17:55:42 +01:00
commit 4e4258a4b5

View File

@ -28,21 +28,21 @@ MODELS = {
}
# Report whether the file for `model` is present under MODELS_DIR.
# The two `return` lines below are the pre-merge and post-merge forms of the
# same statement: the first decodes a bytes name and joins it with "/", the
# second joins through pathlib without decoding.
# NOTE(review): pathlib rejects bytes path segments, while Whisper.__init__
# still builds the name with .encode('utf8') -- confirm callers pass str
# after this merge.
def model_exists(model):
return os.path.exists(MODELS_DIR + "/" + model.decode())
return os.path.exists(Path(MODELS_DIR).joinpath(model))
# Fetch the file for `model` into MODELS_DIR unless model_exists() already
# reports it present. Returns nothing.
# Where two adjacent lines differ only in `.decode()` vs. pathlib usage, they
# are the pre-merge and post-merge forms of one statement.
def download_model(model):
if model_exists(model):
return
print(f'Downloading {model}...')
# The download URL is looked up in the MODELS dict by file name.
url = MODELS[model.decode()]
url = MODELS[model]
# NOTE(review): no status check is visible between the request and the write
# below, so an error response body would be saved as the model file (and a
# later model_exists() would then report it present) -- confirm whether a
# raise_for_status() call is wanted here.
r = requests.get(url, allow_redirects=True)
os.makedirs(MODELS_DIR, exist_ok=True)
with open(MODELS_DIR + "/" + model.decode(), 'wb') as f:
with open(Path(MODELS_DIR).joinpath(model), 'wb') as f:
# r.content is written in a single call.
f.write(r.content)
cdef audio_data load_audio(bytes file, int sr = SAMPLE_RATE):
cdef cnp.ndarray[cnp.float32_t, ndim=1, mode="c"] load_audio(bytes file, int sr = SAMPLE_RATE):
try:
out = (
ffmpeg.input(file, threads=0)
@ -66,11 +66,7 @@ cdef audio_data load_audio(bytes file, int sr = SAMPLE_RATE):
.astype(np.float32)
) / pow(2, 15)
cdef audio_data data;
data.frames = &frames[0]
data.n_frames = len(frames)
return data
return frames
cdef whisper_full_params default_params() nogil:
cdef whisper_full_params params = whisper_full_default_params(
@ -90,7 +86,8 @@ cdef class Whisper:
# Set up a Whisper instance: make sure the ggml file for `model` has been
# downloaded, then create the context via whisper_init() and store the
# default parameters.
# `pb` is not used in the lines visible in this hunk.
def __init__(self, model=DEFAULT_MODEL, pb=None):
# NOTE(review): the name is still encoded to bytes here, but the post-merge
# download_model()/model_exists() lines and the Path.joinpath() call below use
# it without decoding -- bytes are not accepted as a pathlib segment and
# presumably do not match the MODELS keys. Confirm, and drop
# .encode('utf8') if so.
model_fullname = f'ggml-{model}.bin'.encode('utf8')
download_model(model_fullname)
# Pre-merge form: byte-string concatenation of directory and file name.
cdef bytes model_b = MODELS_DIR.encode('utf8') + b'/' + model_fullname
# Post-merge form: build the path with pathlib, then encode it to bytes for
# whisper_init().
model_path = Path(MODELS_DIR).joinpath(model_fullname)
cdef bytes model_b = str(model_path).encode('utf8')
self.ctx = whisper_init(model_b)
self.params = default_params()
whisper_print_system_info()
@ -100,9 +97,10 @@ cdef class Whisper:
# Load the audio in `filename` with load_audio() and run whisper_full() over
# it using this instance's context and parameters; whatever whisper_full()
# returns is passed back unchanged.
def transcribe(self, filename=TEST_FILE):
print("Loading data..")
# NOTE(review): `<bytes>` is a cast without a runtime type check in Cython --
# confirm `filename` (and the TEST_FILE default) really is a bytes object.
# Pre-merge form: load_audio returned an audio_data struct.
cdef audio_data data = load_audio(<bytes>filename)
# Post-merge form: load_audio returns a C-contiguous 1-D float32 array.
cdef cnp.ndarray[cnp.float32_t, ndim=1, mode="c"] frames = load_audio(<bytes>filename)
print("Transcribing..")
# Pre-merge form: pointer and length taken from the audio_data struct.
return whisper_full(self.ctx, self.params, data.frames, data.n_frames)
# Post-merge form: address of the first sample plus the sample count.
# NOTE(review): &frames[0] indexes element 0, which is out of bounds for an
# empty array -- confirm load_audio never returns zero samples.
return whisper_full(self.ctx, self.params, &frames[0], len(frames))
def extract_text(self, int res):
print("Extracting text...")