cdef cnp.ndarray[cnp.float32_t, ndim=1, mode="c"] load_audio_segment(bytes file, float start = 0, float end = 0, int sr = SAMPLE_RATE): # TODO: merge load_audio and load_audio_segment
    """Decodes the slice [start, end] of an audio file to mono float32 PCM.

    Args:
        file: Path to the media file (bytes).
        start: Segment start time in seconds.
        end: Segment end time in seconds. Values <= 0 mean "until the end
            of the stream" — previously the default end=0 always tripped
            the start >= end check, making the defaults unusable.
        sr: Target sample rate (defaults to SAMPLE_RATE).

    Returns:
        A C-contiguous float32 numpy array of samples scaled into [-1, 1).

    Raises:
        ValueError: start/end do not fit inside the file's audio stream.
        RuntimeError: ffmpeg failed to decode the file.
    """
    # Probe the first audio stream so the requested window can be validated
    # against the actual stream extent.
    probe = ffmpeg.probe(file, select_streams='a', ac=1) # TODO: let user specify stream index?

    audio_stream = probe['streams'][0]
    stream_start = float(audio_stream['start_time'])
    stream_end = stream_start + float(probe['format']['duration'])

    # end <= 0 selects everything up to the end of the stream.
    if end <= 0:
        end = stream_end

    if start >= end:
        raise ValueError("start must be less than end")

    if stream_start > start or stream_end < end:
        raise ValueError(f"start value {start} and end value {end} do not make sense with given file's audio stream: {stream_start} to {stream_end}")

    try:
        out, _ = (
            ffmpeg.input(file, threads=0)
            .filter('atrim', start=start, end=end)
            # Rebase timestamps so the trimmed output starts at t=0.
            .filter('asetpts', 'PTS-STARTPTS')
            .output(
                "-", format="s16le",
                acodec="pcm_s16le",
                ac=1, ar=sr
            )
            .run(
                cmd=["ffmpeg", "-nostdin"],
                capture_stdout=True,
                capture_stderr=True
            )
        )
    except ffmpeg.Error as e:
        raise RuntimeError(f"failed to load audio: {e.stderr.decode('utf8')}") from e

    # s16le -> float32 in [-1, 1): scale by 2**15.
    cdef cnp.ndarray[cnp.float32_t, ndim=1, mode="c"] frames = (
        np.frombuffer(out, np.int16)
        .flatten()
        .astype(np.float32)
    ) / pow(2, 15)

    return frames
Automatically checks for model and downloads it if necessary. @@ -156,7 +209,7 @@ cdef class Whisper: filename: Path to file Returns: - Return value of whisper_full for extract_text(...) + Return value of whisper_full for extract_*(...) Raises: RuntimeError: The given file could not be found @@ -167,7 +220,27 @@ cdef class Whisper: #print("Transcribing..") return whisper_full(self.ctx, self.params, &frames[0], len(frames)) - + + def transcribe_segment(self, filename, start, end): + """Transcribes a segment from `start` to `end` from a given file. + + Args: + filename: Path to file + start: Start time + end : End time + + Returns: + Return value of whisper_full for extract_*(...) + + Raises: + RuntimeError: The given file could not be found + ValueError: The given timestamps do not fit the audio stream + """ + if start >= end: + raise ValueError("start must be less than end") + cdef cnp.ndarray[cnp.float32_t, ndim=1, mode="c"] frames = load_audio_segment(filename, start, end) + return whisper_full(self.ctx, self.params, &frames[0], len(frames)) + def extract_text(self, int res): """Extracts the text from a transcription. @@ -188,6 +261,52 @@ cdef class Whisper: whisper_full_get_segment_text(self.ctx, i).decode() for i in range(n_segments) ] + def extract_segment(self, int res, _Bool resolve_lang = True): + """Extracts the text from a transcription. + + Args: + res: A return value from transcribe(...) + resolve_lang: Whether the language param `language` should be resolved to the auto-detected value + + Returns: + A dict of the following format: + { + 'text': '', + 'start': 0.0, + 'end': 20.0, + 'segments': [ + { + 'text': '', + 'start': 0.0, + 'end': 20.0 + }, ... + ], + 'language': 'en' + } + + Raises: + RuntimeError: The given return value was invalid. 
+ """ + if res != 0: + raise RuntimeError + + cdef int n_segments = whisper_full_n_segments(self.ctx) + segments = [] + for i in range(n_segments): + text = whisper_full_get_segment_text(self.ctx, i).decode() + start = float(whisper_full_get_segment_t0(self.ctx, i)) / 100 + end = float(whisper_full_get_segment_t1(self.ctx, i)) / 100 + segments.append({ 'text': text, 'start': start, 'end': end }) + text = ''.join([s['text'] for s in segments]) + start = segments[0]['start'] + end = segments[-1]['end'] + language = self.params.language + + if resolve_lang and language == b'auto': + language = whisper_lang_str(whisper_full_lang_id(self.ctx)) + + return { 'text': text, 'start': start, 'end': end, 'segments': segments, 'language': language.decode() } + def extract_text_and_timestamps(self, int res): """Extracts the text and timestamps from a transcription.