From f73968e1579aca558841175dbe6e83a974e1dcdc Mon Sep 17 00:00:00 2001
From: Luke Southam <luke@devthe.com>
Date: Sun, 11 Dec 2022 00:06:57 +0000
Subject: [PATCH] Added download method.

---
 .vscode/settings.json |  7 -----
 setup.py              | 15 ++++------
 whispercpp.pyx        | 65 +++++++++++++++++++++++++++++++------------
 3 files changed, 53 insertions(+), 34 deletions(-)
 delete mode 100644 .vscode/settings.json

diff --git a/.vscode/settings.json b/.vscode/settings.json
deleted file mode 100644
index bc6d20b..0000000
--- a/.vscode/settings.json
+++ /dev/null
@@ -1,7 +0,0 @@
-{
-    "workbench.colorCustomizations": {
-        "activityBar.background": "#053239",
-        "titleBar.activeBackground": "#074750",
-        "titleBar.activeForeground": "#F2FCFE"
-    }
-}
\ No newline at end of file
diff --git a/setup.py b/setup.py
index e84b4fb..9e15e2d 100644
--- a/setup.py
+++ b/setup.py
@@ -1,17 +1,18 @@
 from distutils.core import setup
 from Cython.Build import cythonize
 import numpy
 import os
 import sys
 
 if sys.platform == 'darwin':
-    os.environ['CFLAGS'] = '-DGGML_USE_ACCELERATE'
-    os.environ['CXXFLAGS'] = '-DGGML_USE_ACCELERATE'
+    # macOS: compile with GGML_USE_ACCELERATE defined and -O3 optimization.
+    os.environ['CFLAGS'] = '-DGGML_USE_ACCELERATE -O3'
+    os.environ['CXXFLAGS'] = '-DGGML_USE_ACCELERATE -O3'
     os.environ['LDFLAGS'] = '-framework Accelerate'
 else:
-    os.environ['CFLAGS'] = '-mavx -mavx2 -mfma -mf16c'
-    os.environ['CXXFLAGS'] = '-mavx -mavx2 -mfma -mf16c'
-
+    # Elsewhere: enable the AVX/AVX2/FMA/F16C code paths and -O3 optimization.
+    os.environ['CFLAGS'] = '-mavx -mavx2 -mfma -mf16c -O3'
+    os.environ['CXXFLAGS'] = '-mavx -mavx2 -mfma -mf16c -O3'
 
 setup(
     name='whispercpp',
diff --git a/whispercpp.pyx b/whispercpp.pyx
index 67c455a..d49bdad 100644
--- a/whispercpp.pyx
+++ b/whispercpp.pyx
@@ -5,27 +5,77 @@
 
 import ffmpeg
 import numpy as np
+import requests
+import os
+
 cimport numpy as cnp
 
 cdef int SAMPLE_RATE = 16000
 cdef char* TEST_FILE = b'test.wav'
-cdef char* DEFAULT_MODEL = b'ggml-tiny.bin'
+cdef char* DEFAULT_MODEL = b'model_ggml_tiny.bin'
 cdef char* LANGUAGE = b'fr'
 
+# Maps each known local model file name (bytes) to the URL it is fetched from.
+MODELS = {
+    b'model_ggml_tiny.bin': 'https://ggml.ggerganov.com/ggml-model-whisper-tiny.bin',
+    b'model_ggml_base.bin': 'https://ggml.ggerganov.com/ggml-model-whisper-base.bin',
+    b'model_ggml_small.bin': 'https://ggml.ggerganov.com/ggml-model-whisper-small.bin',
+    b'model_ggml_medium.bin': 'https://ggml.ggerganov.com/ggml-model-whisper-medium.bin',
+    b'model_ggml_large.bin': 'https://ggml.ggerganov.com/ggml-model-whisper-large.bin',
+}
+
+def model_exists(model):
+    """Return True when the path *model* exists on disk."""
+    return os.path.exists(model)
+
+def download_model(model):
+    """Download *model* from its MODELS URL unless it already exists.
+
+    *model* is the local file name (bytes or str) and must be one of the
+    keys of MODELS; an unknown name raises KeyError. HTTP failures raise
+    the corresponding requests exception, and no partial file is left
+    behind at *model* in that case.
+    """
+    if model_exists(model):
+        return
+
+    print('Downloading model...')
+    path = os.fsencode(model)  # MODELS is keyed by bytes file names
+    url = MODELS[path]
+    partial = path + b'.part'
+    try:
+        # Stream to a temporary file so a failed or interrupted download
+        # never leaves a truncated file that model_exists() would accept.
+        with requests.get(url, allow_redirects=True, stream=True,
+                          timeout=60) as r:
+            r.raise_for_status()
+            with open(partial, 'wb') as f:
+                for chunk in r.iter_content(chunk_size=1 << 20):
+                    f.write(chunk)
+        os.replace(partial, path)
+    finally:
+        if os.path.exists(partial):
+            os.remove(partial)
+
+
 cdef audio_data load_audio(bytes file, int sr = SAMPLE_RATE):
-    out = (
-        ffmpeg.input(file, threads=0)
-        .output(
-            "-", format="s16le",
-            acodec="pcm_s16le",
-            ac=1, ar=sr
-        )
-        .run(
-            cmd=["ffmpeg", "-nostdin"],
-            capture_stdout=True,
-            capture_stderr=True
-        )
-    )[0]
+    try:
+        out = (
+            ffmpeg.input(file, threads=0)
+            .output(
+                "-", format="s16le",
+                acodec="pcm_s16le",
+                ac=1, ar=sr
+            )
+            .run(
+                cmd=["ffmpeg", "-nostdin"],
+                capture_stdout=True,
+                capture_stderr=True
+            )
+        )[0]
+    except Exception as err:
+        # Chain the real cause; let KeyboardInterrupt/SystemExit pass through.
+        raise RuntimeError('File not found') from err
 
     cdef cnp.ndarray[cnp.float32_t, ndim=1, mode="c"] frames = (
         np.frombuffer(out, np.int16)
@@ -55,15 +83,27 @@ cdef class Whisper:
     cdef whisper_full_params params
 
     def __init__(self, char* model=DEFAULT_MODEL):
+        # Fetch the model file first if it is not on disk yet.
+        download_model(model)
         self.ctx = whisper_init(model)
         self.params = default_params()
 
     def __dealloc__(self):
         whisper_free(self.ctx)
 
-    cpdef str transcribe(self):
+    def transcribe(self):
+        """Run whisper_full on TEST_FILE and return its integer result code.
+
+        Pass the returned value to extract_text() to obtain the text.
+        """
         cdef audio_data data = load_audio(TEST_FILE) 
-        cdef int res = whisper_full(self.ctx, self.params, data.frames, data.n_frames)
+        return whisper_full(self.ctx, self.params, data.frames, data.n_frames)
+
+    cpdef str extract_text(self, int res):
+        """Return the decoded text gathered from the context's segments.
+
+        *res* is the value returned by transcribe(); RuntimeError is raised
+        when it is non-zero.
+        """
         if res != 0:
             raise RuntimeError
         cdef int n_segments = whisper_full_n_segments(self.ctx)
@@ -72,5 +103,3 @@ cdef class Whisper:
         ]).decode()
 
 
-
-