Added download method.
This commit is contained in:
parent
295118d41c
commit
f73968e157
7
.vscode/settings.json
vendored
7
.vscode/settings.json
vendored
|
@ -1,7 +0,0 @@
|
||||||
{
|
|
||||||
"workbench.colorCustomizations": {
|
|
||||||
"activityBar.background": "#053239",
|
|
||||||
"titleBar.activeBackground": "#074750",
|
|
||||||
"titleBar.activeForeground": "#F2FCFE"
|
|
||||||
}
|
|
||||||
}
|
|
13
setup.py
13
setup.py
|
@ -1,17 +1,14 @@
|
||||||
from distutils.core import setup
|
from distutils.core import setup
|
||||||
from Cython.Build import cythonize
|
from Cython.Build import cythonize
|
||||||
import numpy
|
import numpy, os, sys
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
|
|
||||||
if sys.platform == 'darwin':
|
if sys.platform == 'darwin':
|
||||||
os.environ['CFLAGS'] = '-DGGML_USE_ACCELERATE'
|
os.environ['CFLAGS'] = '-DGGML_USE_ACCELERATE -O3'
|
||||||
os.environ['CXXFLAGS'] = '-DGGML_USE_ACCELERATE'
|
os.environ['CXXFLAGS'] = '-DGGML_USE_ACCELERATE -O3'
|
||||||
os.environ['LDFLAGS'] = '-framework Accelerate'
|
os.environ['LDFLAGS'] = '-framework Accelerate'
|
||||||
else:
|
else:
|
||||||
os.environ['CFLAGS'] = '-mavx -mavx2 -mfma -mf16c'
|
os.environ['CFLAGS'] = '-mavx -mavx2 -mfma -mf16c -O3'
|
||||||
os.environ['CXXFLAGS'] = '-mavx -mavx2 -mfma -mf16c'
|
os.environ['CXXFLAGS'] = '-mavx -mavx2 -mfma -mf16c -O3'
|
||||||
|
|
||||||
|
|
||||||
setup(
|
setup(
|
||||||
name='whispercpp',
|
name='whispercpp',
|
||||||
|
|
|
@ -5,14 +5,40 @@
|
||||||
|
|
||||||
import ffmpeg
|
import ffmpeg
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import requests
|
||||||
|
import os
|
||||||
|
|
||||||
cimport numpy as cnp
|
cimport numpy as cnp
|
||||||
|
|
||||||
cdef int SAMPLE_RATE = 16000
|
cdef int SAMPLE_RATE = 16000
|
||||||
cdef char* TEST_FILE = b'test.wav'
|
cdef char* TEST_FILE = b'test.wav'
|
||||||
cdef char* DEFAULT_MODEL = b'ggml-tiny.bin'
|
cdef char* DEFAULT_MODEL = b'model_ggml_tiny.bin'
|
||||||
cdef char* LANGUAGE = b'fr'
|
cdef char* LANGUAGE = b'fr'
|
||||||
|
|
||||||
|
# Maps each local model file name (bytes, matching the C-string model
# arguments used in this module) to the URL its weights are fetched from.
MODELS = {
    f'model_ggml_{size}.bin'.encode():
        f'https://ggml.ggerganov.com/ggml-model-whisper-{size}.bin'
    for size in ('tiny', 'base', 'small', 'medium', 'large')
}
|
||||||
|
|
||||||
|
def model_exists(model):
    """Return True when ``model`` names an existing regular file.

    Args:
        model: Path of the model file, as ``str`` or ``bytes``.

    Returns:
        bool: True only for a regular file (or a symlink to one). A
        directory that happens to carry the model's name does not count,
        so the caller still downloads the weights instead of later failing
        to load a non-file.
    """
    return os.path.isfile(model)
|
||||||
|
|
||||||
|
def download_model(model):
    """Download the weights file for ``model`` unless it is already on disk.

    The payload is streamed to a temporary ``<model>.part`` file and moved
    into place only after the transfer completed, so an interrupted or
    failed download never leaves a truncated file under the final name.

    Args:
        model: Local file name of the model. It is used both as the key
            into ``MODELS`` and as the destination path.

    Raises:
        KeyError: If the file is missing and ``MODELS`` has no URL for it.
        requests.HTTPError: If the server answers with an error status.
    """
    if model_exists(model):
        return

    # Fail before announcing a download that cannot happen.
    if model not in MODELS:
        raise KeyError(f'No download URL known for model {model!r}')
    url = MODELS[model]

    print('Downloading model...')
    target = os.fsencode(model)
    partial = target + b'.part'
    try:
        # stream=True keeps the file out of memory; the timeout is
        # (connect, read-between-chunks), not a limit on the whole transfer.
        with requests.get(url, allow_redirects=True, stream=True,
                          timeout=(10, 60)) as r:
            # Without this an HTML error page would be saved as the model.
            r.raise_for_status()
            with open(partial, 'wb') as f:
                for chunk in r.iter_content(chunk_size=1 << 20):
                    f.write(chunk)
        os.replace(partial, target)
    except BaseException:
        # Also covers Ctrl-C: never leave a half-written file behind.
        if os.path.exists(partial):
            os.remove(partial)
        raise
|
||||||
|
|
||||||
|
|
||||||
cdef audio_data load_audio(bytes file, int sr = SAMPLE_RATE):
|
cdef audio_data load_audio(bytes file, int sr = SAMPLE_RATE):
|
||||||
|
try:
|
||||||
out = (
|
out = (
|
||||||
ffmpeg.input(file, threads=0)
|
ffmpeg.input(file, threads=0)
|
||||||
.output(
|
.output(
|
||||||
|
@ -26,6 +52,8 @@ cdef audio_data load_audio(bytes file, int sr = SAMPLE_RATE):
|
||||||
capture_stderr=True
|
capture_stderr=True
|
||||||
)
|
)
|
||||||
)[0]
|
)[0]
|
||||||
|
except:
|
||||||
|
raise RuntimeError('File not found')
|
||||||
|
|
||||||
cdef cnp.ndarray[cnp.float32_t, ndim=1, mode="c"] frames = (
|
cdef cnp.ndarray[cnp.float32_t, ndim=1, mode="c"] frames = (
|
||||||
np.frombuffer(out, np.int16)
|
np.frombuffer(out, np.int16)
|
||||||
|
@ -55,15 +83,18 @@ cdef class Whisper:
|
||||||
cdef whisper_full_params params
|
cdef whisper_full_params params
|
||||||
|
|
||||||
def __init__(self, char* model=DEFAULT_MODEL):
|
def __init__(self, char* model=DEFAULT_MODEL):
|
||||||
|
download_model(model)
|
||||||
self.ctx = whisper_init(model)
|
self.ctx = whisper_init(model)
|
||||||
self.params = default_params()
|
self.params = default_params()
|
||||||
|
|
||||||
def __dealloc__(self):
|
def __dealloc__(self):
|
||||||
whisper_free(self.ctx)
|
whisper_free(self.ctx)
|
||||||
|
|
||||||
cpdef str transcribe(self):
|
def transcribe(self):
|
||||||
cdef audio_data data = load_audio(TEST_FILE)
|
cdef audio_data data = load_audio(TEST_FILE)
|
||||||
cdef int res = whisper_full(self.ctx, self.params, data.frames, data.n_frames)
|
return whisper_full(self.ctx, self.params, data.frames, data.n_frames)
|
||||||
|
|
||||||
|
cpdef str extract_text(self, int res):
|
||||||
if res != 0:
|
if res != 0:
|
||||||
raise RuntimeError
|
raise RuntimeError
|
||||||
cdef int n_segments = whisper_full_n_segments(self.ctx)
|
cdef int n_segments = whisper_full_n_segments(self.ctx)
|
||||||
|
@ -72,5 +103,3 @@ cdef class Whisper:
|
||||||
]).decode()
|
]).decode()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user