# cc9c123251
#
# Changes from https://github.com/iantanwx/whispercpp.py:
#   - added `transcribe_segment`, `extract_segment` and `load_audio_segment`
#   - use `ffmpeg.Error`
# Other:
#   - added `whisper_full_lang_id`
#   - improved timestamps checking
#
# 336 lines, 10 KiB, Cython
#!python
|
|
# cython: language_level=3
|
|
|
|
import ffmpeg
|
|
import numpy as np
|
|
import urllib.request
|
|
import os
|
|
from pathlib import Path
|
|
|
|
# Default directory for ggml model files: ".ggml-models" inside the user's home directory.
MODELS_DIR = str(Path.home() / '.ggml-models')
|
|
|
|
|
|
cimport numpy as cnp
|
|
|
|
# Default audio sample rate handed to ffmpeg (`ar=`) when decoding input files.
cdef int SAMPLE_RATE = 16000

# Default argument values for Whisper.transcribe() and Whisper.__init__().
cdef char* TEST_FILE = 'test.wav'
cdef char* DEFAULT_MODEL = 'base'
cdef char* LANGUAGE = b'en'

# os.cpu_count() returns None when the number of CPUs cannot be determined;
# None cannot be converted to a C int, so fall back to a single thread.
cdef int N_THREADS = os.cpu_count() or 1

# Default values for the whisper.cpp parameter flags set in set_params().
cdef _Bool PRINT_REALTIME = False
cdef _Bool PRINT_PROGRESS = False
cdef _Bool TOKEN_TIMESTAMPS = False
cdef _Bool PRINT_TIMESTAMPS = True
cdef _Bool TRANSLATE = False
|
|
|
|
|
|
# Maps each downloadable ggml model file name to its download URL.
# Every file lives under the same base URL and follows the 'ggml-<variant>.bin' scheme.
MODELS = {
    f'ggml-{variant}.bin':
        f'https://huggingface.co/datasets/ggerganov/whisper.cpp/resolve/main/ggml-{variant}.bin'
    for variant in (
        'tiny',
        'tiny.en',
        'base',
        'base.en',
        'small',
        'small.en',
        'medium',
        'medium.en',
        'large-v1',
        'large',
    )
}
|
|
|
|
def model_exists(model, models_dir=MODELS_DIR):
    """Tell whether a file named `model` is present inside `models_dir`.

    Args:
        model: File name of the model, joined onto `models_dir`
        models_dir: Directory to look in

    Returns:
        True if the joined path exists, False otherwise.
    """
    model_path = Path(models_dir) / model
    return os.path.exists(model_path)
|
|
|
|
def download_model(model, models_dir=MODELS_DIR):
    """Downloads the ggml model file with the given name, unless it already exists.

    The file names mirror the ones given in ggerganov's repos and are the
    keys of MODELS, e.g. 'ggml-small.bin'.

    The download is streamed into a temporary '<name>.part' file which is
    renamed to its final name only after it completed, so an interrupted
    download never leaves a truncated file that model_exists() would accept.

    Args:
        model: The model file name (a key of MODELS)
        models_dir: The path where the file is written to

    Raises:
        KeyError: The file is not present locally and `model` is not a key of MODELS
    """
    import shutil

    if model_exists(model, models_dir=models_dir):
        return

    try:
        url = MODELS[model]
    except KeyError:
        raise KeyError(
            f'unknown model {model!r}; expected one of: {", ".join(MODELS)}'
        ) from None

    print(f'Downloading {model} to {models_dir}...')
    os.makedirs(models_dir, exist_ok=True)

    target = Path(models_dir).joinpath(model)
    partial = Path(f'{target}.part')
    try:
        with urllib.request.urlopen(url) as r:
            with open(partial, 'wb') as f:
                # Copy in chunks; model files can be several GB in size.
                shutil.copyfileobj(r, f)
        os.replace(partial, target)
    except BaseException:
        # Best-effort cleanup of the incomplete download; the original error
        # is what the caller needs to see, so a failed removal is ignored.
        try:
            os.remove(partial)
        except OSError:
            pass
        raise
|
|
|
|
|
|
def list_languages():
    """Enumerates the languages known to whisper.cpp.

    Every id from 0 up to and including whisper_lang_max_id() is paired with
    the decoded string that whisper_lang_str() returns for it.

    Returns:
        A list of (id, language code) tuples,
        e.g. [(0, "en"), (1, "zh"), ...]
    """
    cdef int n_languages = whisper_lang_max_id() + 1
    return [
        (lang_id, whisper_lang_str(lang_id).decode())
        for lang_id in range(n_languages)
    ]
|
|
|
|
|
|
cdef cnp.ndarray[cnp.float32_t, ndim=1, mode="c"] load_audio(bytes file, int sr = SAMPLE_RATE):
    """Decodes a whole file with ffmpeg into a 1-D float32 sample array.

    ffmpeg is asked for mono (`ac=1`) signed 16-bit little-endian PCM at
    `sr` on stdout; the samples are then divided by 2**15.

    Args:
        file: Path of the input file
        sr: Sample rate requested from ffmpeg

    Returns:
        The decoded samples as a float32 array.

    Raises:
        RuntimeError: ffmpeg failed; the message carries ffmpeg's stderr
    """
    cdef cnp.ndarray[cnp.float32_t, ndim=1, mode="c"] frames

    # Building the stream graph does not run ffmpeg yet; only run() does.
    pcm_stream = ffmpeg.input(file, threads=0).output(
        "-",
        format="s16le",
        acodec="pcm_s16le",
        ac=1,
        ar=sr,
    )

    try:
        raw_pcm, _ = pcm_stream.run(
            cmd=["ffmpeg", "-nostdin"],
            capture_stdout=True,
            capture_stderr=True,
        )
    except ffmpeg.Error as e:
        raise RuntimeError(f"failed to load audio: {e.stderr.decode('utf8')}") from e

    samples = np.frombuffer(raw_pcm, np.int16).flatten()
    frames = samples.astype(np.float32) / pow(2, 15)
    return frames
|
|
|
|
cdef cnp.ndarray[cnp.float32_t, ndim=1, mode="c"] load_audio_segment(bytes file, float start = 0, float end = 0, int sr = SAMPLE_RATE): # TODO: merge load_audio and load_audio_segment
    """Decodes the part of a file between `start` and `end` into float32 samples.

    The file is probed first so that the requested range can be checked
    against the audio stream. ffmpeg then trims the audio (`atrim`), resets
    the timestamps (`asetpts`) and emits mono signed 16-bit little-endian PCM
    at `sr`, which is divided by 2**15.

    Note that the default values (start == end == 0) are rejected, so callers
    always have to pass a valid range.

    Args:
        file: Path of the input file
        start: Start of the segment, on the same time base as ffprobe's `start_time`
        end: End of the segment, must be greater than `start`
        sr: Sample rate requested from ffmpeg

    Returns:
        The decoded samples as a float32 array.

    Raises:
        ValueError: `start` is not less than `end`, or the range lies outside the stream
        RuntimeError: ffprobe/ffmpeg failed or the file has no audio stream
    """
    cdef cnp.ndarray[cnp.float32_t, ndim=1, mode="c"] frames

    if start >= end:
        raise ValueError("start must be less than end")

    # check if start and end values fit the input
    # Probe failures (e.g. a missing file) are reported as RuntimeError, the
    # same way the decoding step below reports them.
    try:
        probe = ffmpeg.probe(file, select_streams='a', ac=1) # TODO: let user specify stream index?
    except ffmpeg.Error as e:
        raise RuntimeError(f"failed to load audio: {e.stderr.decode('utf8')}") from e

    streams = probe.get('streams')
    if not streams:
        raise RuntimeError("failed to load audio: no audio stream found")

    audio_stream = streams[0]
    # Treat a stream without a reported start_time as starting at 0.
    stream_start = float(audio_stream.get('start_time', 0.0))
    stream_end = stream_start + float(probe['format']['duration'])

    # NOTE(review): `start`/`end` are C floats (single precision) while the
    # probed values are doubles, so a bound that equals the stream end exactly
    # may compare as out of range after rounding - confirm whether the
    # parameters should be widened to double.
    if stream_start > start or stream_end < end:
        raise ValueError(f"start value {start} and end value {end} do not make sense with given file's audio stream: {stream_start} to {stream_end}")

    try:
        out, _ = (
            ffmpeg.input(file, threads=0)
            .filter('atrim', start=start, end=end)
            .filter('asetpts', 'PTS-STARTPTS')
            .output(
                "-", format="s16le",
                acodec="pcm_s16le",
                ac=1, ar=sr
            )
            .run(
                cmd=["ffmpeg", "-nostdin"],
                capture_stdout=True,
                capture_stderr=True
            )
        )
    except ffmpeg.Error as e:
        raise RuntimeError(f"failed to load audio: {e.stderr.decode('utf8')}") from e

    frames = (
        np.frombuffer(out, np.int16)
        .flatten()
        .astype(np.float32)
    ) / pow(2, 15)

    return frames
|
|
|
|
cdef whisper_full_params set_params(_Bool print_realtime, _Bool print_progress, _Bool token_timestamps, _Bool print_timestamps, _Bool translate, char* language, int n_threads) nogil:
    """Builds a whisper_full_params struct from the given option values.

    Starts from whisper.cpp's defaults for the greedy sampling strategy and
    overrides the fields that correspond to the arguments. `language` is
    stored as a pointer, not copied, so the caller has to keep the underlying
    buffer alive for as long as the returned struct is in use.
    """
    cdef whisper_full_params cfg = whisper_full_default_params(
        whisper_sampling_strategy.WHISPER_SAMPLING_GREEDY
    )

    # Transcription options
    cfg.language = <const char *> language
    cfg.translate = translate
    cfg.token_timestamps = token_timestamps
    cfg.n_threads = n_threads

    # Output toggles
    cfg.print_realtime = print_realtime
    cfg.print_progress = print_progress
    cfg.print_timestamps = print_timestamps

    return cfg
|
|
|
|
cdef class Whisper:
    """Owns a whisper.cpp context plus the parameters used for transcription.

    Typical use: call transcribe(...) or transcribe_segment(...) and pass the
    returned value to one of the extract_*(...) methods.
    """
    cdef whisper_context * ctx
    cdef whisper_full_params params
    # Owned copy of the language code. `params.language` points into this
    # object's buffer, so it has to live as long as the instance does.
    cdef bytes _language

    def __init__(
        self,
        model = DEFAULT_MODEL,
        models_dir = MODELS_DIR,
        _Bool print_realtime = PRINT_REALTIME,
        _Bool print_progress = PRINT_PROGRESS,
        _Bool token_timestamps = TOKEN_TIMESTAMPS,
        _Bool print_timestamps = PRINT_TIMESTAMPS,
        _Bool translate = TRANSLATE,
        char* language = LANGUAGE,
        int n_threads = N_THREADS,
        _Bool print_system_info = False
    ): # not pretty, look for a way to use kwargs?
        """Constructor for Whisper class.

        Automatically checks for model and downloads it if necessary.

        Args:
            model: Model identifier, e.g. 'base' (see MODELS); str or bytes
            models_dir: The path where the models should be stored
            print_realtime: whisper.cpp's real time transcription output
            print_progress: whisper.cpp's progress indicator
            token_timestamps: output timestamps for tokens
            print_timestamps: whisper.cpp's timestamped output
            translate: whisper.cpp's translation option
            language: Which language to use. Must be a byte string.
            n_threads: Amount of threads to use
            print_system_info: whisper.cpp's system info output

        Raises:
            KeyError: The model is neither present locally nor listed in MODELS
            RuntimeError: whisper.cpp could not load the model file
        """
        cdef bytes model_b
        cdef bytes language_b

        # The default comes from a C string and therefore arrives as bytes;
        # formatting bytes into the file name would yield "ggml-b'base'.bin".
        if isinstance(model, bytes):
            model = model.decode('utf8')

        model_fullname = f'ggml-{model}.bin'
        download_model(model_fullname, models_dir=models_dir)
        model_path = Path(models_dir).joinpath(model_fullname)
        model_b = str(model_path).encode('utf8')

        # Do not leak an existing context if __init__ is invoked again.
        if self.ctx is not NULL:
            whisper_free(self.ctx)
            self.ctx = NULL

        self.ctx = whisper_init_from_file(model_b)
        if self.ctx is NULL:
            raise RuntimeError(f"failed to load model from '{model_path}'")

        # `language` points into the caller's bytes object, which may be freed
        # once this call returns. Copy it and keep the copy referenced by the
        # instance so that the pointer stored in the params stays valid.
        language_b = language
        self._language = language_b
        self.params = set_params(print_realtime, print_progress, token_timestamps, print_timestamps, translate, language_b, n_threads)

        if print_system_info:
            print(whisper_print_system_info().decode())

    def __dealloc__(self):
        # ctx is still NULL when __init__ never ran or failed before loading.
        if self.ctx is not NULL:
            whisper_free(self.ctx)
            self.ctx = NULL

    def transcribe(self, filename = TEST_FILE):
        """Transcribes from given file.

        Args:
            filename: Path to file

        Returns:
            Return value of whisper_full for extract_*(...)

        Raises:
            RuntimeError: The given file could not be loaded or contains no audio samples
        """
        cdef cnp.ndarray[cnp.float32_t, ndim=1, mode="c"] frames = load_audio(<bytes>filename)

        # &frames[0] below is only valid for a non-empty array.
        if frames.shape[0] == 0:
            raise RuntimeError("failed to load audio: no samples were decoded")

        return whisper_full(self.ctx, self.params, &frames[0], frames.shape[0])

    def transcribe_segment(self, filename, start, end):
        """Transcribes a segment from `start` to `end` from a given file.

        Args:
            filename: Path to file
            start: Start time
            end  : End time

        Returns:
            Return value of whisper_full for extract_*(...)

        Raises:
            RuntimeError: The given file could not be loaded or the segment contains no audio samples
            ValueError: The given timestamps do not fit the audio stream
        """
        cdef cnp.ndarray[cnp.float32_t, ndim=1, mode="c"] frames

        if start >= end:
            raise ValueError("start must be less than end")

        frames = load_audio_segment(<bytes>filename, start, end)

        # &frames[0] below is only valid for a non-empty array.
        if frames.shape[0] == 0:
            raise RuntimeError("failed to load audio: no samples were decoded")

        return whisper_full(self.ctx, self.params, &frames[0], frames.shape[0])

    def extract_text(self, int res):
        """Extracts the text from a transcription.

        Args:
            res: A return value from transcribe(...)

        Returns:
            A list of transcribed strings, one per segment.

        Raises:
            RuntimeError: The given return value was invalid.
        """
        cdef int n_segments
        cdef int i

        if res != 0:
            raise RuntimeError(f"transcription failed (whisper_full returned {res})")

        n_segments = whisper_full_n_segments(self.ctx)
        texts = []
        for i in range(n_segments):
            texts.append(whisper_full_get_segment_text(self.ctx, i).decode())
        return texts

    def extract_segment(self, int res, _Bool resolve_lang = True):
        """Extracts text, timing and language information from a transcription.

        The segment times are whisper.cpp's t0/t1 values divided by 100. The
        top-level 'start' and 'end' are taken from the first and the last
        segment; both are 0.0 when the transcription produced no segments.

        Args:
            res: A return value from transcribe(...)
            resolve_lang: Whether the language param `language` should be resolved to the auto-detected value

        Returns:
            A dict of the following format:
            {
                'text': '',
                'start': 0.0,
                'end': 20.0,
                'segments': [
                    {
                        'text': '',
                        'start': 0.0,
                        'end': 20.0
                    }, ...
                ],
                'language': 'en'
            }

        Raises:
            RuntimeError: The given return value was invalid.
        """
        cdef int n_segments
        cdef int i
        cdef bytes language
        cdef const char * detected

        if res != 0:
            raise RuntimeError(f"transcription failed (whisper_full returned {res})")

        n_segments = whisper_full_n_segments(self.ctx)
        segments = []
        for i in range(n_segments):
            segments.append({
                'text': whisper_full_get_segment_text(self.ctx, i).decode(),
                'start': float(whisper_full_get_segment_t0(self.ctx, i)) / 100,
                'end': float(whisper_full_get_segment_t1(self.ctx, i)) / 100,
            })

        text = ''.join([s['text'] for s in segments])
        # A transcription may legitimately contain no segments at all.
        start = segments[0]['start'] if segments else 0.0
        end = segments[-1]['end'] if segments else 0.0

        language = self.params.language
        if resolve_lang and language == b'auto':
            detected = whisper_lang_str(whisper_full_lang_id(self.ctx))
            # Keep 'auto' if whisper.cpp has no string for the reported id.
            if detected is not NULL:
                language = detected

        return { 'text': text, 'start': start, 'end': end, 'segments': segments, 'language': language.decode() }

    def extract_text_and_timestamps(self, int res):
        """Extracts the text and timestamps from a transcription.

        The timestamps are returned exactly as whisper.cpp reports them.

        Args:
            res: A return value from transcribe(...)

        Returns:
            A list of tuples containing start time, end time and transcribed text.
            e.g. [(0, 500, " This is a test.")]

        Raises:
            RuntimeError: The given return value was invalid.
        """
        cdef int n_segments
        cdef int i

        if res != 0:
            raise RuntimeError(f"transcription failed (whisper_full returned {res})")

        n_segments = whisper_full_n_segments(self.ctx)
        results = []
        for i in range(n_segments):
            results.append((
                whisper_full_get_segment_t0(self.ctx, i),
                whisper_full_get_segment_t1(self.ctx, i),
                whisper_full_get_segment_text(self.ctx, i).decode()
            ))
        return results
|
|
|
|
|