#!python
# cython: language_level=3

import ffmpeg
import numpy as np
import urllib.request
import shutil
import os
from pathlib import Path

MODELS_DIR = str(Path('~/.ggml-models').expanduser())

cimport numpy as cnp

cdef int SAMPLE_RATE = 16000
cdef char* TEST_FILE = b'test.wav'
cdef char* DEFAULT_MODEL = b'base'
cdef char* LANGUAGE = b'en'
cdef int N_THREADS = os.cpu_count()
cdef _Bool PRINT_REALTIME = False
cdef _Bool PRINT_PROGRESS = False
cdef _Bool TOKEN_TIMESTAMPS = False
cdef _Bool PRINT_TIMESTAMPS = True
cdef _Bool TRANSLATE = False

MODELS = {
    'ggml-tiny.bin': 'https://huggingface.co/datasets/ggerganov/whisper.cpp/resolve/main/ggml-tiny.bin',
    'ggml-tiny.en.bin': 'https://huggingface.co/datasets/ggerganov/whisper.cpp/resolve/main/ggml-tiny.en.bin',
    'ggml-base.bin': 'https://huggingface.co/datasets/ggerganov/whisper.cpp/resolve/main/ggml-base.bin',
    'ggml-base.en.bin': 'https://huggingface.co/datasets/ggerganov/whisper.cpp/resolve/main/ggml-base.en.bin',
    'ggml-small.bin': 'https://huggingface.co/datasets/ggerganov/whisper.cpp/resolve/main/ggml-small.bin',
    'ggml-small.en.bin': 'https://huggingface.co/datasets/ggerganov/whisper.cpp/resolve/main/ggml-small.en.bin',
    'ggml-medium.bin': 'https://huggingface.co/datasets/ggerganov/whisper.cpp/resolve/main/ggml-medium.bin',
    'ggml-medium.en.bin': 'https://huggingface.co/datasets/ggerganov/whisper.cpp/resolve/main/ggml-medium.en.bin',
    'ggml-large-v1.bin': 'https://huggingface.co/datasets/ggerganov/whisper.cpp/resolve/main/ggml-large-v1.bin',
    'ggml-large.bin': 'https://huggingface.co/datasets/ggerganov/whisper.cpp/resolve/main/ggml-large.bin',
}


def model_exists(model, models_dir=MODELS_DIR):
    """Returns True if the given model file already exists in models_dir."""
    return os.path.exists(Path(models_dir).joinpath(model))


def download_model(model, models_dir=MODELS_DIR):
    """Downloads the ggml model file with the given filename.

    The filenames mirror the ones used in ggerganov's whisper.cpp repository,
    e.g. the 'small' model is stored as 'ggml-small.bin'.

    Args:
        model: The model filename, e.g. 'ggml-base.bin' (a key of MODELS)
        models_dir: The directory the file is written to
    """
    if model_exists(model, models_dir=models_dir):
        return

    print(f'Downloading {model} to {models_dir}...')
    url = MODELS[model]
    os.makedirs(models_dir, exist_ok=True)
    with urllib.request.urlopen(url) as r:
        with open(Path(models_dir).joinpath(model), 'wb') as f:
            # Stream the download to disk instead of buffering the whole
            # model file (up to several GB) in memory.
            shutil.copyfileobj(r, f)
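
# Example (sketch): models can also be fetched ahead of time, outside of the
# Whisper class below; the expected filenames are the keys of MODELS.
#
#     download_model('ggml-base.en.bin')
#     model_exists('ggml-base.en.bin')   # -> True once the download finished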
""" cdef int max_id = whisper_lang_max_id() + 1 cdef list results = [] for i in range(max_id): results.append(( i, whisper_lang_str(i).decode() )) return results cdef cnp.ndarray[cnp.float32_t, ndim=1, mode="c"] load_audio(bytes file, int sr = SAMPLE_RATE): try: out = ( ffmpeg.input(file, threads=0) .output( "-", format="s16le", acodec="pcm_s16le", ac=1, ar=sr ) .run( cmd=["ffmpeg", "-nostdin"], capture_stdout=True, capture_stderr=True ) )[0] except Exception: raise RuntimeError(f"File '{file}' not found") cdef cnp.ndarray[cnp.float32_t, ndim=1, mode="c"] frames = ( np.frombuffer(out, np.int16) .flatten() .astype(np.float32) ) / pow(2, 15) return frames cdef whisper_full_params set_params(_Bool print_realtime, _Bool print_progress, _Bool token_timestamps, _Bool print_timestamps, _Bool translate, char* language, int n_threads) nogil: cdef whisper_full_params params = whisper_full_default_params( whisper_sampling_strategy.WHISPER_SAMPLING_GREEDY ) params.print_realtime = print_realtime params.print_progress = print_progress params.token_timestamps = token_timestamps params.print_timestamps = print_timestamps params.translate = translate params.language = language params.n_threads = n_threads return params cdef class Whisper: cdef whisper_context * ctx cdef whisper_full_params params def __init__(self, model = DEFAULT_MODEL, models_dir = MODELS_DIR, _Bool print_realtime = PRINT_REALTIME, _Bool print_progress = PRINT_PROGRESS, _Bool token_timestamps = TOKEN_TIMESTAMPS, _Bool print_timestamps = PRINT_TIMESTAMPS, _Bool translate = TRANSLATE, char* language = LANGUAGE, int n_threads = N_THREADS, _Bool print_system_info = False): # not pretty, look for a way to use kwargs? """Constructor for Whisper class. Automatically checks for model and downloads it if necessary. Args: model: Model identifier, e.g. 'base' (see MODELS) models_dir: The path where the models should be stored print_realtime: whisper.cpp's real time transcription output print_progress: whisper.cpp's progress indicator token_timestamps: output timestamps for tokens print_timestamps: whisper.cpp's timestamped output translate: whisper.cpp's translation option language: Which language to use. Must be a byte string. n_threads: Amount of threads to use print_system_info: whisper.cpp's system info output """ model_fullname = f'ggml-{model}.bin' #.encode('utf8') download_model(model_fullname, models_dir=models_dir) model_path = Path(models_dir).joinpath(model_fullname) cdef bytes model_b = str(model_path).encode('utf8') self.ctx = whisper_init_from_file(model_b) self.params = set_params(print_realtime, print_progress, token_timestamps, print_timestamps, translate, language, n_threads) if print_system_info: print(whisper_print_system_info().decode()) def __dealloc__(self): whisper_free(self.ctx) def transcribe(self, filename = TEST_FILE): """Transcribes from given file. Args: filename: Path to file Returns: Return value of whisper_full for extract_text(...) Raises: RuntimeError: The given file could not be found """ #print(f"Loading data from '{filename}'...") cdef cnp.ndarray[cnp.float32_t, ndim=1, mode="c"] frames = load_audio(filename) #print("Transcribing..") return whisper_full(self.ctx, self.params, &frames[0], len(frames)) def extract_text(self, int res): """Extracts the text from a transcription. Args: res: A return value from transcribe(...) Returns: A list of transcribed strings. Raises: RuntimeError: The given return value was invalid. 
""" #print("Extracting text...") if res != 0: raise RuntimeError cdef int n_segments = whisper_full_n_segments(self.ctx) return [ whisper_full_get_segment_text(self.ctx, i).decode() for i in range(n_segments) ] def extract_text_and_timestamps(self, int res): """Extracts the text and timestamps from a transcription. Args: res: A return value from transcribe(...) Returns: A list of tuples containing start time, end time and transcribed text. e.g. [(0, 500, " This is a test.")] Raises: RuntimeError: The given return value was invalid. """ if res != 0: raise RuntimeError cdef int n_segments = whisper_full_n_segments(self.ctx) results = [] for i in range(n_segments): results.append(( whisper_full_get_segment_t0(self.ctx, i), whisper_full_get_segment_t1(self.ctx, i), whisper_full_get_segment_text(self.ctx, i).decode() )) return results