Added new extraction and trimming functionality
Changes from https://github.com/iantanwx/whispercpp.py: - added `transcribe_segment`, `extract_segment` and `load_audio_segment` - use `ffmpeg.Error` Other: - added `whisper_full_lang_id` - improved timestamps checking
This commit is contained in:
parent
2a4f9b3e40
commit
cc9c123251
|
@ -102,6 +102,7 @@ cdef extern from "whisper.h" nogil:
|
||||||
cdef int whisper_full(whisper_context*, whisper_full_params, float*, int)
|
cdef int whisper_full(whisper_context*, whisper_full_params, float*, int)
|
||||||
cdef int whisper_full_parallel(whisper_context*, whisper_full_params, float*, int, int)
|
cdef int whisper_full_parallel(whisper_context*, whisper_full_params, float*, int, int)
|
||||||
cdef int whisper_full_n_segments(whisper_context*)
|
cdef int whisper_full_n_segments(whisper_context*)
|
||||||
|
cdef int whisper_full_lang_id(whisper_context*)
|
||||||
cdef int64_t whisper_full_get_segment_t0(whisper_context*, int)
|
cdef int64_t whisper_full_get_segment_t0(whisper_context*, int)
|
||||||
cdef int64_t whisper_full_get_segment_t1(whisper_context*, int)
|
cdef int64_t whisper_full_get_segment_t1(whisper_context*, int)
|
||||||
# Unknown CtypesSpecial name='c_char_p'
|
# Unknown CtypesSpecial name='c_char_p'
|
||||||
|
|
129
whispercpp.pyx
129
whispercpp.pyx
|
@ -92,8 +92,49 @@ cdef cnp.ndarray[cnp.float32_t, ndim=1, mode="c"] load_audio(bytes file, int sr
|
||||||
capture_stderr=True
|
capture_stderr=True
|
||||||
)
|
)
|
||||||
)[0]
|
)[0]
|
||||||
except Exception:
|
except ffmpeg.Error as e:
|
||||||
raise RuntimeError(f"File '{file}' not found")
|
raise RuntimeError(f"failed to load audio: {e.stderr.decode('utf8')}") from e
|
||||||
|
|
||||||
|
cdef cnp.ndarray[cnp.float32_t, ndim=1, mode="c"] frames = (
|
||||||
|
np.frombuffer(out, np.int16)
|
||||||
|
.flatten()
|
||||||
|
.astype(np.float32)
|
||||||
|
) / pow(2, 15)
|
||||||
|
|
||||||
|
return frames
|
||||||
|
|
||||||
|
cdef cnp.ndarray[cnp.float32_t, ndim=1, mode="c"] load_audio_segment(bytes file, float start = 0, float end = 0, int sr = SAMPLE_RATE): # TODO: merge load_audio and load_audio_segment
|
||||||
|
if start >= end:
|
||||||
|
raise ValueError("start must be less than end")
|
||||||
|
|
||||||
|
# check if start and end values fit the input
|
||||||
|
probe = ffmpeg.probe(file, select_streams='a', ac=1) # TODO: let user specify stream index?
|
||||||
|
|
||||||
|
audio_stream = probe['streams'][0]
|
||||||
|
stream_start = float(audio_stream['start_time'])
|
||||||
|
stream_end = stream_start + float(probe['format']['duration'])
|
||||||
|
|
||||||
|
if stream_start > start or stream_end < end:
|
||||||
|
raise ValueError(f"start value {start} and end value {end} do not make sense with given file's audio stream: {stream_start} to {stream_end}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
out, _ = (
|
||||||
|
ffmpeg.input(file, threads=0)
|
||||||
|
.filter('atrim', start=start, end=end)
|
||||||
|
.filter('asetpts', 'PTS-STARTPTS')
|
||||||
|
.output(
|
||||||
|
"-", format="s16le",
|
||||||
|
acodec="pcm_s16le",
|
||||||
|
ac=1, ar=sr
|
||||||
|
)
|
||||||
|
.run(
|
||||||
|
cmd=["ffmpeg", "-nostdin"],
|
||||||
|
capture_stdout=True,
|
||||||
|
capture_stderr=True
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except ffmpeg.Error as e:
|
||||||
|
raise RuntimeError(f"failed to load audio: {e.stderr.decode('utf8')}") from e
|
||||||
|
|
||||||
cdef cnp.ndarray[cnp.float32_t, ndim=1, mode="c"] frames = (
|
cdef cnp.ndarray[cnp.float32_t, ndim=1, mode="c"] frames = (
|
||||||
np.frombuffer(out, np.int16)
|
np.frombuffer(out, np.int16)
|
||||||
|
@ -120,7 +161,19 @@ cdef class Whisper:
|
||||||
cdef whisper_context * ctx
|
cdef whisper_context * ctx
|
||||||
cdef whisper_full_params params
|
cdef whisper_full_params params
|
||||||
|
|
||||||
def __init__(self, model = DEFAULT_MODEL, models_dir = MODELS_DIR, _Bool print_realtime = PRINT_REALTIME, _Bool print_progress = PRINT_PROGRESS, _Bool token_timestamps = TOKEN_TIMESTAMPS, _Bool print_timestamps = PRINT_TIMESTAMPS, _Bool translate = TRANSLATE, char* language = LANGUAGE, int n_threads = N_THREADS, _Bool print_system_info = False): # not pretty, look for a way to use kwargs?
|
def __init__(
|
||||||
|
self,
|
||||||
|
model = DEFAULT_MODEL,
|
||||||
|
models_dir = MODELS_DIR,
|
||||||
|
_Bool print_realtime = PRINT_REALTIME,
|
||||||
|
_Bool print_progress = PRINT_PROGRESS,
|
||||||
|
_Bool token_timestamps = TOKEN_TIMESTAMPS,
|
||||||
|
_Bool print_timestamps = PRINT_TIMESTAMPS,
|
||||||
|
_Bool translate = TRANSLATE,
|
||||||
|
char* language = LANGUAGE,
|
||||||
|
int n_threads = N_THREADS,
|
||||||
|
_Bool print_system_info = False
|
||||||
|
): # not pretty, look for a way to use kwargs?
|
||||||
"""Constructor for Whisper class.
|
"""Constructor for Whisper class.
|
||||||
|
|
||||||
Automatically checks for model and downloads it if necessary.
|
Automatically checks for model and downloads it if necessary.
|
||||||
|
@ -156,7 +209,7 @@ cdef class Whisper:
|
||||||
filename: Path to file
|
filename: Path to file
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Return value of whisper_full for extract_text(...)
|
Return value of whisper_full for extract_*(...)
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
RuntimeError: The given file could not be found
|
RuntimeError: The given file could not be found
|
||||||
|
@ -167,7 +220,27 @@ cdef class Whisper:
|
||||||
|
|
||||||
#print("Transcribing..")
|
#print("Transcribing..")
|
||||||
return whisper_full(self.ctx, self.params, &frames[0], len(frames))
|
return whisper_full(self.ctx, self.params, &frames[0], len(frames))
|
||||||
|
|
||||||
|
def transcribe_segment(self, filename, start, end):
|
||||||
|
"""Transcribes a segment from `start` to `end` from a given file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
filename: Path to file
|
||||||
|
start: Start time
|
||||||
|
end : End time
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Return value of whisper_full for extract_*(...)
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError: The given file could not be found
|
||||||
|
ValueError: The given timestamps do not fit the audio stream
|
||||||
|
"""
|
||||||
|
if start >= end:
|
||||||
|
raise ValueError("start must be less than end")
|
||||||
|
cdef cnp.ndarray[cnp.float32_t, ndim=1, mode="c"] frames = load_audio_segment(<bytes>filename, start, end)
|
||||||
|
return whisper_full(self.ctx, self.params, &frames[0], len(frames))
|
||||||
|
|
||||||
def extract_text(self, int res):
|
def extract_text(self, int res):
|
||||||
"""Extracts the text from a transcription.
|
"""Extracts the text from a transcription.
|
||||||
|
|
||||||
|
@ -188,6 +261,52 @@ cdef class Whisper:
|
||||||
whisper_full_get_segment_text(self.ctx, i).decode() for i in range(n_segments)
|
whisper_full_get_segment_text(self.ctx, i).decode() for i in range(n_segments)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
def extract_segment(self, int res, _Bool resolve_lang = True):
|
||||||
|
"""Extracts the text from a transcription.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
res: A return value from transcribe(...)
|
||||||
|
resolve_lang: Whether the language param `language` should be resolved to the auto-detected value
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A dict of the following format:
|
||||||
|
{
|
||||||
|
'text': '',
|
||||||
|
'start': 0.0,
|
||||||
|
'end': 20.0,
|
||||||
|
'segments': [
|
||||||
|
{
|
||||||
|
'text': '',
|
||||||
|
'start': 0.0,
|
||||||
|
'end': 20.0
|
||||||
|
}, ...
|
||||||
|
],
|
||||||
|
'language': 'en'
|
||||||
|
}
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError: The given return value was invalid.
|
||||||
|
"""
|
||||||
|
if res != 0:
|
||||||
|
raise RuntimeError
|
||||||
|
|
||||||
|
cdef int n_segments = whisper_full_n_segments(self.ctx)
|
||||||
|
segments = []
|
||||||
|
for i in range(n_segments):
|
||||||
|
text = whisper_full_get_segment_text(self.ctx, i).decode()
|
||||||
|
start = float(whisper_full_get_segment_t0(self.ctx, i)) / 100
|
||||||
|
end = float(whisper_full_get_segment_t1(self.ctx, i)) / 100
|
||||||
|
segments.append({ 'text': text, 'start': start, 'end': end })
|
||||||
|
text = ''.join([s['text'] for s in segments])
|
||||||
|
start = segments[0]['start']
|
||||||
|
end = segments[-1]['end']
|
||||||
|
language = <bytes>self.params.language
|
||||||
|
|
||||||
|
if resolve_lang and language == b'auto':
|
||||||
|
language = whisper_lang_str(whisper_full_lang_id(self.ctx))
|
||||||
|
|
||||||
|
return { 'text': text, 'start': start, 'end': end, 'segments': segments, 'language': language.decode() }
|
||||||
|
|
||||||
def extract_text_and_timestamps(self, int res):
|
def extract_text_and_timestamps(self, int res):
|
||||||
"""Extracts the text and timestamps from a transcription.
|
"""Extracts the text and timestamps from a transcription.
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user