Added new extraction and trimming functionality

Changes from https://github.com/iantanwx/whispercpp.py:
- added `transcribe_segment`, `extract_segment` and `load_audio_segment`
- use `ffmpeg.Error`
Other:
- added `whisper_full_lang_id`
- improved timestamps checking
This commit is contained in:
lightmare 2023-03-06 08:47:23 +00:00
parent 2a4f9b3e40
commit cc9c123251
2 changed files with 125 additions and 5 deletions

View File

@ -102,6 +102,7 @@ cdef extern from "whisper.h" nogil:
cdef int whisper_full(whisper_context*, whisper_full_params, float*, int) cdef int whisper_full(whisper_context*, whisper_full_params, float*, int)
cdef int whisper_full_parallel(whisper_context*, whisper_full_params, float*, int, int) cdef int whisper_full_parallel(whisper_context*, whisper_full_params, float*, int, int)
cdef int whisper_full_n_segments(whisper_context*) cdef int whisper_full_n_segments(whisper_context*)
cdef int whisper_full_lang_id(whisper_context*)
cdef int64_t whisper_full_get_segment_t0(whisper_context*, int) cdef int64_t whisper_full_get_segment_t0(whisper_context*, int)
cdef int64_t whisper_full_get_segment_t1(whisper_context*, int) cdef int64_t whisper_full_get_segment_t1(whisper_context*, int)
# Unknown CtypesSpecial name='c_char_p' # Unknown CtypesSpecial name='c_char_p'

View File

@ -92,8 +92,49 @@ cdef cnp.ndarray[cnp.float32_t, ndim=1, mode="c"] load_audio(bytes file, int sr
capture_stderr=True capture_stderr=True
) )
)[0] )[0]
except Exception: except ffmpeg.Error as e:
raise RuntimeError(f"File '{file}' not found") raise RuntimeError(f"failed to load audio: {e.stderr.decode('utf8')}") from e
cdef cnp.ndarray[cnp.float32_t, ndim=1, mode="c"] frames = (
np.frombuffer(out, np.int16)
.flatten()
.astype(np.float32)
) / pow(2, 15)
return frames
cdef cnp.ndarray[cnp.float32_t, ndim=1, mode="c"] load_audio_segment(bytes file, float start = 0, float end = 0, int sr = SAMPLE_RATE): # TODO: merge load_audio and load_audio_segment
if start >= end:
raise ValueError("start must be less than end")
# check if start and end values fit the input
probe = ffmpeg.probe(file, select_streams='a', ac=1) # TODO: let user specify stream index?
audio_stream = probe['streams'][0]
stream_start = float(audio_stream['start_time'])
stream_end = stream_start + float(probe['format']['duration'])
if stream_start > start or stream_end < end:
raise ValueError(f"start value {start} and end value {end} do not make sense with given file's audio stream: {stream_start} to {stream_end}")
try:
out, _ = (
ffmpeg.input(file, threads=0)
.filter('atrim', start=start, end=end)
.filter('asetpts', 'PTS-STARTPTS')
.output(
"-", format="s16le",
acodec="pcm_s16le",
ac=1, ar=sr
)
.run(
cmd=["ffmpeg", "-nostdin"],
capture_stdout=True,
capture_stderr=True
)
)
except ffmpeg.Error as e:
raise RuntimeError(f"failed to load audio: {e.stderr.decode('utf8')}") from e
cdef cnp.ndarray[cnp.float32_t, ndim=1, mode="c"] frames = ( cdef cnp.ndarray[cnp.float32_t, ndim=1, mode="c"] frames = (
np.frombuffer(out, np.int16) np.frombuffer(out, np.int16)
@ -120,7 +161,19 @@ cdef class Whisper:
cdef whisper_context * ctx cdef whisper_context * ctx
cdef whisper_full_params params cdef whisper_full_params params
def __init__(self, model = DEFAULT_MODEL, models_dir = MODELS_DIR, _Bool print_realtime = PRINT_REALTIME, _Bool print_progress = PRINT_PROGRESS, _Bool token_timestamps = TOKEN_TIMESTAMPS, _Bool print_timestamps = PRINT_TIMESTAMPS, _Bool translate = TRANSLATE, char* language = LANGUAGE, int n_threads = N_THREADS, _Bool print_system_info = False): # not pretty, look for a way to use kwargs? def __init__(
self,
model = DEFAULT_MODEL,
models_dir = MODELS_DIR,
_Bool print_realtime = PRINT_REALTIME,
_Bool print_progress = PRINT_PROGRESS,
_Bool token_timestamps = TOKEN_TIMESTAMPS,
_Bool print_timestamps = PRINT_TIMESTAMPS,
_Bool translate = TRANSLATE,
char* language = LANGUAGE,
int n_threads = N_THREADS,
_Bool print_system_info = False
): # not pretty, look for a way to use kwargs?
"""Constructor for Whisper class. """Constructor for Whisper class.
Automatically checks for model and downloads it if necessary. Automatically checks for model and downloads it if necessary.
@ -156,7 +209,7 @@ cdef class Whisper:
filename: Path to file filename: Path to file
Returns: Returns:
Return value of whisper_full for extract_text(...) Return value of whisper_full for extract_*(...)
Raises: Raises:
RuntimeError: The given file could not be found RuntimeError: The given file could not be found
@ -167,7 +220,27 @@ cdef class Whisper:
#print("Transcribing..") #print("Transcribing..")
return whisper_full(self.ctx, self.params, &frames[0], len(frames)) return whisper_full(self.ctx, self.params, &frames[0], len(frames))
def transcribe_segment(self, filename, start, end):
"""Transcribes a segment from `start` to `end` from a given file.
Args:
filename: Path to file
start: Start time
end : End time
Returns:
Return value of whisper_full for extract_*(...)
Raises:
RuntimeError: The given file could not be found
ValueError: The given timestamps do not fit the audio stream
"""
if start >= end:
raise ValueError("start must be less than end")
cdef cnp.ndarray[cnp.float32_t, ndim=1, mode="c"] frames = load_audio_segment(<bytes>filename, start, end)
return whisper_full(self.ctx, self.params, &frames[0], len(frames))
def extract_text(self, int res): def extract_text(self, int res):
"""Extracts the text from a transcription. """Extracts the text from a transcription.
@ -188,6 +261,52 @@ cdef class Whisper:
whisper_full_get_segment_text(self.ctx, i).decode() for i in range(n_segments) whisper_full_get_segment_text(self.ctx, i).decode() for i in range(n_segments)
] ]
def extract_segment(self, int res, _Bool resolve_lang = True):
"""Extracts the text from a transcription.
Args:
res: A return value from transcribe(...)
resolve_lang: Whether the language param `language` should be resolved to the auto-detected value
Returns:
A dict of the following format:
{
'text': '',
'start': 0.0,
'end': 20.0,
'segments': [
{
'text': '',
'start': 0.0,
'end': 20.0
}, ...
],
'language': 'en'
}
Raises:
RuntimeError: The given return value was invalid.
"""
if res != 0:
raise RuntimeError
cdef int n_segments = whisper_full_n_segments(self.ctx)
segments = []
for i in range(n_segments):
text = whisper_full_get_segment_text(self.ctx, i).decode()
start = float(whisper_full_get_segment_t0(self.ctx, i)) / 100
end = float(whisper_full_get_segment_t1(self.ctx, i)) / 100
segments.append({ 'text': text, 'start': start, 'end': end })
text = ''.join([s['text'] for s in segments])
start = segments[0]['start']
end = segments[-1]['end']
language = <bytes>self.params.language
if resolve_lang and language == b'auto':
language = whisper_lang_str(whisper_full_lang_id(self.ctx))
return { 'text': text, 'start': start, 'end': end, 'segments': segments, 'language': language.decode() }
def extract_text_and_timestamps(self, int res): def extract_text_and_timestamps(self, int res):
"""Extracts the text and timestamps from a transcription. """Extracts the text and timestamps from a transcription.