Compare commits
No commits in common. "7b6fa0d8196cd579f48e7575ab5a18b88028d086" and "2a4f9b3e402cc9b1086b6f50ee447f98aec02164" have entirely different histories.
7b6fa0d819
...
2a4f9b3e40
2
setup.py
2
setup.py
|
@ -25,7 +25,7 @@ whisper_clib = ('whisper_clib', {'sources': ['whisper.cpp/ggml.c']})
|
||||||
|
|
||||||
setup(
|
setup(
|
||||||
name='whispercpp',
|
name='whispercpp',
|
||||||
version='1.2.0',
|
version='1.1.0',
|
||||||
description='Python bindings for whisper.cpp - ecker edition',
|
description='Python bindings for whisper.cpp - ecker edition',
|
||||||
author='lightmare',
|
author='lightmare',
|
||||||
author_email='',
|
author_email='',
|
||||||
|
|
|
@ -102,7 +102,6 @@ cdef extern from "whisper.h" nogil:
|
||||||
cdef int whisper_full(whisper_context*, whisper_full_params, float*, int)
|
cdef int whisper_full(whisper_context*, whisper_full_params, float*, int)
|
||||||
cdef int whisper_full_parallel(whisper_context*, whisper_full_params, float*, int, int)
|
cdef int whisper_full_parallel(whisper_context*, whisper_full_params, float*, int, int)
|
||||||
cdef int whisper_full_n_segments(whisper_context*)
|
cdef int whisper_full_n_segments(whisper_context*)
|
||||||
cdef int whisper_full_lang_id(whisper_context*)
|
|
||||||
cdef int64_t whisper_full_get_segment_t0(whisper_context*, int)
|
cdef int64_t whisper_full_get_segment_t0(whisper_context*, int)
|
||||||
cdef int64_t whisper_full_get_segment_t1(whisper_context*, int)
|
cdef int64_t whisper_full_get_segment_t1(whisper_context*, int)
|
||||||
# Unknown CtypesSpecial name='c_char_p'
|
# Unknown CtypesSpecial name='c_char_p'
|
||||||
|
|
129
whispercpp.pyx
129
whispercpp.pyx
|
@ -92,49 +92,8 @@ cdef cnp.ndarray[cnp.float32_t, ndim=1, mode="c"] load_audio(bytes file, int sr
|
||||||
capture_stderr=True
|
capture_stderr=True
|
||||||
)
|
)
|
||||||
)[0]
|
)[0]
|
||||||
except ffmpeg.Error as e:
|
except Exception:
|
||||||
raise RuntimeError(f"failed to load audio: {e.stderr.decode('utf8')}") from e
|
raise RuntimeError(f"File '{file}' not found")
|
||||||
|
|
||||||
cdef cnp.ndarray[cnp.float32_t, ndim=1, mode="c"] frames = (
|
|
||||||
np.frombuffer(out, np.int16)
|
|
||||||
.flatten()
|
|
||||||
.astype(np.float32)
|
|
||||||
) / pow(2, 15)
|
|
||||||
|
|
||||||
return frames
|
|
||||||
|
|
||||||
cdef cnp.ndarray[cnp.float32_t, ndim=1, mode="c"] load_audio_segment(bytes file, float start = 0, float end = 0, int sr = SAMPLE_RATE): # TODO: merge load_audio and load_audio_segment
|
|
||||||
if start >= end:
|
|
||||||
raise ValueError("start must be less than end")
|
|
||||||
|
|
||||||
# check if start and end values fit the input
|
|
||||||
probe = ffmpeg.probe(file, select_streams='a', ac=1) # TODO: let user specify stream index?
|
|
||||||
|
|
||||||
audio_stream = probe['streams'][0]
|
|
||||||
stream_start = float(audio_stream['start_time'])
|
|
||||||
stream_end = stream_start + float(probe['format']['duration'])
|
|
||||||
|
|
||||||
if stream_start > start or stream_end < end:
|
|
||||||
raise ValueError(f"start value {start} and end value {end} do not make sense with given file's audio stream: {stream_start} to {stream_end}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
out, _ = (
|
|
||||||
ffmpeg.input(file, threads=0)
|
|
||||||
.filter('atrim', start=start, end=end)
|
|
||||||
.filter('asetpts', 'PTS-STARTPTS')
|
|
||||||
.output(
|
|
||||||
"-", format="s16le",
|
|
||||||
acodec="pcm_s16le",
|
|
||||||
ac=1, ar=sr
|
|
||||||
)
|
|
||||||
.run(
|
|
||||||
cmd=["ffmpeg", "-nostdin"],
|
|
||||||
capture_stdout=True,
|
|
||||||
capture_stderr=True
|
|
||||||
)
|
|
||||||
)
|
|
||||||
except ffmpeg.Error as e:
|
|
||||||
raise RuntimeError(f"failed to load audio: {e.stderr.decode('utf8')}") from e
|
|
||||||
|
|
||||||
cdef cnp.ndarray[cnp.float32_t, ndim=1, mode="c"] frames = (
|
cdef cnp.ndarray[cnp.float32_t, ndim=1, mode="c"] frames = (
|
||||||
np.frombuffer(out, np.int16)
|
np.frombuffer(out, np.int16)
|
||||||
|
@ -161,19 +120,7 @@ cdef class Whisper:
|
||||||
cdef whisper_context * ctx
|
cdef whisper_context * ctx
|
||||||
cdef whisper_full_params params
|
cdef whisper_full_params params
|
||||||
|
|
||||||
def __init__(
|
def __init__(self, model = DEFAULT_MODEL, models_dir = MODELS_DIR, _Bool print_realtime = PRINT_REALTIME, _Bool print_progress = PRINT_PROGRESS, _Bool token_timestamps = TOKEN_TIMESTAMPS, _Bool print_timestamps = PRINT_TIMESTAMPS, _Bool translate = TRANSLATE, char* language = LANGUAGE, int n_threads = N_THREADS, _Bool print_system_info = False): # not pretty, look for a way to use kwargs?
|
||||||
self,
|
|
||||||
model = DEFAULT_MODEL,
|
|
||||||
models_dir = MODELS_DIR,
|
|
||||||
_Bool print_realtime = PRINT_REALTIME,
|
|
||||||
_Bool print_progress = PRINT_PROGRESS,
|
|
||||||
_Bool token_timestamps = TOKEN_TIMESTAMPS,
|
|
||||||
_Bool print_timestamps = PRINT_TIMESTAMPS,
|
|
||||||
_Bool translate = TRANSLATE,
|
|
||||||
char* language = LANGUAGE,
|
|
||||||
int n_threads = N_THREADS,
|
|
||||||
_Bool print_system_info = False
|
|
||||||
): # not pretty, look for a way to use kwargs?
|
|
||||||
"""Constructor for Whisper class.
|
"""Constructor for Whisper class.
|
||||||
|
|
||||||
Automatically checks for model and downloads it if necessary.
|
Automatically checks for model and downloads it if necessary.
|
||||||
|
@ -209,7 +156,7 @@ cdef class Whisper:
|
||||||
filename: Path to file
|
filename: Path to file
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Return value of whisper_full for extract_*(...)
|
Return value of whisper_full for extract_text(...)
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
RuntimeError: The given file could not be found
|
RuntimeError: The given file could not be found
|
||||||
|
@ -220,27 +167,7 @@ cdef class Whisper:
|
||||||
|
|
||||||
#print("Transcribing..")
|
#print("Transcribing..")
|
||||||
return whisper_full(self.ctx, self.params, &frames[0], len(frames))
|
return whisper_full(self.ctx, self.params, &frames[0], len(frames))
|
||||||
|
|
||||||
def transcribe_segment(self, filename, start, end):
|
|
||||||
"""Transcribes a segment from `start` to `end` from a given file.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
filename: Path to file
|
|
||||||
start: Start time
|
|
||||||
end : End time
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Return value of whisper_full for extract_*(...)
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
RuntimeError: The given file could not be found
|
|
||||||
ValueError: The given timestamps do not fit the audio stream
|
|
||||||
"""
|
|
||||||
if start >= end:
|
|
||||||
raise ValueError("start must be less than end")
|
|
||||||
cdef cnp.ndarray[cnp.float32_t, ndim=1, mode="c"] frames = load_audio_segment(<bytes>filename, start, end)
|
|
||||||
return whisper_full(self.ctx, self.params, &frames[0], len(frames))
|
|
||||||
|
|
||||||
def extract_text(self, int res):
|
def extract_text(self, int res):
|
||||||
"""Extracts the text from a transcription.
|
"""Extracts the text from a transcription.
|
||||||
|
|
||||||
|
@ -261,52 +188,6 @@ cdef class Whisper:
|
||||||
whisper_full_get_segment_text(self.ctx, i).decode() for i in range(n_segments)
|
whisper_full_get_segment_text(self.ctx, i).decode() for i in range(n_segments)
|
||||||
]
|
]
|
||||||
|
|
||||||
def extract_segment(self, int res, _Bool resolve_lang = True):
|
|
||||||
"""Extracts the text from a transcription.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
res: A return value from transcribe(...)
|
|
||||||
resolve_lang: Whether the language param `language` should be resolved to the auto-detected value
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A dict of the following format:
|
|
||||||
{
|
|
||||||
'text': '',
|
|
||||||
'start': 0.0,
|
|
||||||
'end': 20.0,
|
|
||||||
'segments': [
|
|
||||||
{
|
|
||||||
'text': '',
|
|
||||||
'start': 0.0,
|
|
||||||
'end': 20.0
|
|
||||||
}, ...
|
|
||||||
],
|
|
||||||
'language': 'en'
|
|
||||||
}
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
RuntimeError: The given return value was invalid.
|
|
||||||
"""
|
|
||||||
if res != 0:
|
|
||||||
raise RuntimeError
|
|
||||||
|
|
||||||
cdef int n_segments = whisper_full_n_segments(self.ctx)
|
|
||||||
segments = []
|
|
||||||
for i in range(n_segments):
|
|
||||||
text = whisper_full_get_segment_text(self.ctx, i).decode()
|
|
||||||
start = float(whisper_full_get_segment_t0(self.ctx, i)) / 100
|
|
||||||
end = float(whisper_full_get_segment_t1(self.ctx, i)) / 100
|
|
||||||
segments.append({ 'text': text, 'start': start, 'end': end })
|
|
||||||
text = ''.join([s['text'] for s in segments])
|
|
||||||
start = segments[0]['start']
|
|
||||||
end = segments[-1]['end']
|
|
||||||
language = <bytes>self.params.language
|
|
||||||
|
|
||||||
if resolve_lang and language == b'auto':
|
|
||||||
language = whisper_lang_str(whisper_full_lang_id(self.ctx))
|
|
||||||
|
|
||||||
return { 'text': text, 'start': start, 'end': end, 'segments': segments, 'language': language.decode() }
|
|
||||||
|
|
||||||
def extract_text_and_timestamps(self, int res):
|
def extract_text_and_timestamps(self, int res):
|
||||||
"""Extracts the text and timestamps from a transcription.
|
"""Extracts the text and timestamps from a transcription.
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user