Added new extraction and trimming functionality
Changes from https://github.com/iantanwx/whispercpp.py: - added `transcribe_segment`, `extract_segment` and `load_audio_segment` - use `ffmpeg.Error` Other: - added `whisper_full_lang_id` - improved timestamps checking
This commit is contained in:
parent
2a4f9b3e40
commit
cc9c123251
|
@ -102,6 +102,7 @@ cdef extern from "whisper.h" nogil:
|
|||
cdef int whisper_full(whisper_context*, whisper_full_params, float*, int)
|
||||
cdef int whisper_full_parallel(whisper_context*, whisper_full_params, float*, int, int)
|
||||
cdef int whisper_full_n_segments(whisper_context*)
|
||||
cdef int whisper_full_lang_id(whisper_context*)
|
||||
cdef int64_t whisper_full_get_segment_t0(whisper_context*, int)
|
||||
cdef int64_t whisper_full_get_segment_t1(whisper_context*, int)
|
||||
# Unknown CtypesSpecial name='c_char_p'
|
||||
|
|
127
whispercpp.pyx
127
whispercpp.pyx
|
@ -92,8 +92,49 @@ cdef cnp.ndarray[cnp.float32_t, ndim=1, mode="c"] load_audio(bytes file, int sr
|
|||
capture_stderr=True
|
||||
)
|
||||
)[0]
|
||||
except Exception:
|
||||
raise RuntimeError(f"File '{file}' not found")
|
||||
except ffmpeg.Error as e:
|
||||
raise RuntimeError(f"failed to load audio: {e.stderr.decode('utf8')}") from e
|
||||
|
||||
cdef cnp.ndarray[cnp.float32_t, ndim=1, mode="c"] frames = (
|
||||
np.frombuffer(out, np.int16)
|
||||
.flatten()
|
||||
.astype(np.float32)
|
||||
) / pow(2, 15)
|
||||
|
||||
return frames
|
||||
|
||||
cdef cnp.ndarray[cnp.float32_t, ndim=1, mode="c"] load_audio_segment(bytes file, float start = 0, float end = 0, int sr = SAMPLE_RATE): # TODO: merge load_audio and load_audio_segment
|
||||
if start >= end:
|
||||
raise ValueError("start must be less than end")
|
||||
|
||||
# check if start and end values fit the input
|
||||
probe = ffmpeg.probe(file, select_streams='a', ac=1) # TODO: let user specify stream index?
|
||||
|
||||
audio_stream = probe['streams'][0]
|
||||
stream_start = float(audio_stream['start_time'])
|
||||
stream_end = stream_start + float(probe['format']['duration'])
|
||||
|
||||
if stream_start > start or stream_end < end:
|
||||
raise ValueError(f"start value {start} and end value {end} do not make sense with given file's audio stream: {stream_start} to {stream_end}")
|
||||
|
||||
try:
|
||||
out, _ = (
|
||||
ffmpeg.input(file, threads=0)
|
||||
.filter('atrim', start=start, end=end)
|
||||
.filter('asetpts', 'PTS-STARTPTS')
|
||||
.output(
|
||||
"-", format="s16le",
|
||||
acodec="pcm_s16le",
|
||||
ac=1, ar=sr
|
||||
)
|
||||
.run(
|
||||
cmd=["ffmpeg", "-nostdin"],
|
||||
capture_stdout=True,
|
||||
capture_stderr=True
|
||||
)
|
||||
)
|
||||
except ffmpeg.Error as e:
|
||||
raise RuntimeError(f"failed to load audio: {e.stderr.decode('utf8')}") from e
|
||||
|
||||
cdef cnp.ndarray[cnp.float32_t, ndim=1, mode="c"] frames = (
|
||||
np.frombuffer(out, np.int16)
|
||||
|
@ -120,7 +161,19 @@ cdef class Whisper:
|
|||
cdef whisper_context * ctx
|
||||
cdef whisper_full_params params
|
||||
|
||||
def __init__(self, model = DEFAULT_MODEL, models_dir = MODELS_DIR, _Bool print_realtime = PRINT_REALTIME, _Bool print_progress = PRINT_PROGRESS, _Bool token_timestamps = TOKEN_TIMESTAMPS, _Bool print_timestamps = PRINT_TIMESTAMPS, _Bool translate = TRANSLATE, char* language = LANGUAGE, int n_threads = N_THREADS, _Bool print_system_info = False): # not pretty, look for a way to use kwargs?
|
||||
def __init__(
|
||||
self,
|
||||
model = DEFAULT_MODEL,
|
||||
models_dir = MODELS_DIR,
|
||||
_Bool print_realtime = PRINT_REALTIME,
|
||||
_Bool print_progress = PRINT_PROGRESS,
|
||||
_Bool token_timestamps = TOKEN_TIMESTAMPS,
|
||||
_Bool print_timestamps = PRINT_TIMESTAMPS,
|
||||
_Bool translate = TRANSLATE,
|
||||
char* language = LANGUAGE,
|
||||
int n_threads = N_THREADS,
|
||||
_Bool print_system_info = False
|
||||
): # not pretty, look for a way to use kwargs?
|
||||
"""Constructor for Whisper class.
|
||||
|
||||
Automatically checks for model and downloads it if necessary.
|
||||
|
@ -156,7 +209,7 @@ cdef class Whisper:
|
|||
filename: Path to file
|
||||
|
||||
Returns:
|
||||
Return value of whisper_full for extract_text(...)
|
||||
Return value of whisper_full for extract_*(...)
|
||||
|
||||
Raises:
|
||||
RuntimeError: The given file could not be found
|
||||
|
@ -168,6 +221,26 @@ cdef class Whisper:
|
|||
#print("Transcribing..")
|
||||
return whisper_full(self.ctx, self.params, &frames[0], len(frames))
|
||||
|
||||
def transcribe_segment(self, filename, start, end):
|
||||
"""Transcribes a segment from `start` to `end` from a given file.
|
||||
|
||||
Args:
|
||||
filename: Path to file
|
||||
start: Start time
|
||||
end : End time
|
||||
|
||||
Returns:
|
||||
Return value of whisper_full for extract_*(...)
|
||||
|
||||
Raises:
|
||||
RuntimeError: The given file could not be found
|
||||
ValueError: The given timestamps do not fit the audio stream
|
||||
"""
|
||||
if start >= end:
|
||||
raise ValueError("start must be less than end")
|
||||
cdef cnp.ndarray[cnp.float32_t, ndim=1, mode="c"] frames = load_audio_segment(<bytes>filename, start, end)
|
||||
return whisper_full(self.ctx, self.params, &frames[0], len(frames))
|
||||
|
||||
def extract_text(self, int res):
|
||||
"""Extracts the text from a transcription.
|
||||
|
||||
|
@ -188,6 +261,52 @@ cdef class Whisper:
|
|||
whisper_full_get_segment_text(self.ctx, i).decode() for i in range(n_segments)
|
||||
]
|
||||
|
||||
def extract_segment(self, int res, _Bool resolve_lang = True):
|
||||
"""Extracts the text from a transcription.
|
||||
|
||||
Args:
|
||||
res: A return value from transcribe(...)
|
||||
resolve_lang: Whether the language param `language` should be resolved to the auto-detected value
|
||||
|
||||
Returns:
|
||||
A dict of the following format:
|
||||
{
|
||||
'text': '',
|
||||
'start': 0.0,
|
||||
'end': 20.0,
|
||||
'segments': [
|
||||
{
|
||||
'text': '',
|
||||
'start': 0.0,
|
||||
'end': 20.0
|
||||
}, ...
|
||||
],
|
||||
'language': 'en'
|
||||
}
|
||||
|
||||
Raises:
|
||||
RuntimeError: The given return value was invalid.
|
||||
"""
|
||||
if res != 0:
|
||||
raise RuntimeError
|
||||
|
||||
cdef int n_segments = whisper_full_n_segments(self.ctx)
|
||||
segments = []
|
||||
for i in range(n_segments):
|
||||
text = whisper_full_get_segment_text(self.ctx, i).decode()
|
||||
start = float(whisper_full_get_segment_t0(self.ctx, i)) / 100
|
||||
end = float(whisper_full_get_segment_t1(self.ctx, i)) / 100
|
||||
segments.append({ 'text': text, 'start': start, 'end': end })
|
||||
text = ''.join([s['text'] for s in segments])
|
||||
start = segments[0]['start']
|
||||
end = segments[-1]['end']
|
||||
language = <bytes>self.params.language
|
||||
|
||||
if resolve_lang and language == b'auto':
|
||||
language = whisper_lang_str(whisper_full_lang_id(self.ctx))
|
||||
|
||||
return { 'text': text, 'start': start, 'end': end, 'segments': segments, 'language': language.decode() }
|
||||
|
||||
def extract_text_and_timestamps(self, int res):
|
||||
"""Extracts the text and timestamps from a transcription.
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user