updated whisperX integration for use with the latest version (v3) (NOTE: you WILL need to also update whisperx if you pull this commit)

This commit is contained in:
mrq 2023-06-09 02:41:29 +00:00
parent 805d7d35e8
commit e227ab8e08

View File

@ -177,8 +177,6 @@ webui = None
voicefixer = None voicefixer = None
whisper_model = None whisper_model = None
whisper_vad = None
whisper_diarize = None
whisper_align_model = None whisper_align_model = None
training_state = None training_state = None
@ -1999,8 +1997,6 @@ def whisper_sanitize( results ):
def whisper_transcribe( file, language=None ): def whisper_transcribe( file, language=None ):
# shouldn't happen, but it's for safety # shouldn't happen, but it's for safety
global whisper_model global whisper_model
global whisper_vad
global whisper_diarize
global whisper_align_model global whisper_align_model
if not whisper_model: if not whisper_model:
@ -2035,37 +2031,12 @@ def whisper_transcribe( file, language=None ):
if args.whisper_backend == "m-bain/whisperx": if args.whisper_backend == "m-bain/whisperx":
import whisperx import whisperx
from whisperx.diarize import assign_word_speakers
device = "cuda" if get_device_name() == "cuda" else "cpu" device = "cuda" if get_device_name() == "cuda" else "cpu"
if whisper_vad: result = whisper_model.transcribe(file, batch_size=args.whisper_batchsize)
# omits a considerable amount of the end
if args.whisper_batchsize > 1:
result = whisperx.transcribe_with_vad_parallel(whisper_model, file, whisper_vad, batch_size=args.whisper_batchsize, language=language, task="transcribe")
else:
result = whisperx.transcribe_with_vad(whisper_model, file, whisper_vad)
"""
result = whisperx.transcribe_with_vad(whisper_model, file, whisper_vad)
"""
else:
result = whisper_model.transcribe(file)
align_model, metadata = whisper_align_model align_model, metadata = whisper_align_model
result_aligned = whisperx.align(result["segments"], align_model, metadata, file, device) result_aligned = whisperx.align(result["segments"], align_model, metadata, file, device, return_char_alignments=False)
if whisper_diarize:
diarize_segments = whisper_diarize(file)
diarize_df = pd.DataFrame(diarize_segments.itertracks(yield_label=True))
diarize_df['start'] = diarize_df[0].apply(lambda x: x.start)
diarize_df['end'] = diarize_df[0].apply(lambda x: x.end)
# assumes each utterance is single speaker (needs fix)
result_segments, word_segments = assign_word_speakers(diarize_df, result_aligned["segments"], fill_nearest=True)
result_aligned["segments"] = result_segments
result_aligned["word_segments"] = word_segments
for i in range(len(result_aligned['segments'])):
del result_aligned['segments'][i]['word-segments']
del result_aligned['segments'][i]['char-segments']
result['segments'] = result_aligned['segments'] result['segments'] = result_aligned['segments']
result['text'] = [] result['text'] = []
@ -3625,8 +3596,6 @@ def unload_voicefixer():
def load_whisper_model(language=None, model_name=None, progress=None): def load_whisper_model(language=None, model_name=None, progress=None):
global whisper_model global whisper_model
global whisper_vad
global whisper_diarize
global whisper_align_model global whisper_align_model
if args.whisper_backend not in WHISPER_BACKENDS: if args.whisper_backend not in WHISPER_BACKENDS:
@ -3662,45 +3631,15 @@ def load_whisper_model(language=None, model_name=None, progress=None):
elif args.whisper_backend == "m-bain/whisperx": elif args.whisper_backend == "m-bain/whisperx":
import whisper, whisperx import whisper, whisperx
device = "cuda" if get_device_name() == "cuda" else "cpu" device = "cuda" if get_device_name() == "cuda" else "cpu"
try: whisper_model = whisperx.load_model(model_name, device)
whisper_model = whisperx.load_model(model_name, device)
except Exception as e:
whisper_model = whisper.load_model(model_name, device)
if not args.hf_token:
print("No huggingface token used, needs to be saved in environment variable, otherwise will throw error loading VAD model.")
try:
from pyannote.audio import Inference, Pipeline
whisper_vad = Inference(
"pyannote/segmentation",
pre_aggregation_hook=lambda segmentation: segmentation,
use_auth_token=args.hf_token,
device=torch.device(device),
)
# whisper_diarize = Pipeline.from_pretrained("pyannote/speaker-diarization@2.1",use_auth_token=args.hf_token)
except Exception as e:
pass
whisper_align_model = whisperx.load_align_model(model_name="WAV2VEC2_ASR_LARGE_LV60K_960H" if language=="en" else None, language_code=language, device=device) whisper_align_model = whisperx.load_align_model(model_name="WAV2VEC2_ASR_LARGE_LV60K_960H" if language=="en" else None, language_code=language, device=device)
print("Loaded Whisper model") print("Loaded Whisper model")
def unload_whisper(): def unload_whisper():
global whisper_model global whisper_model
global whisper_vad
global whisper_diarize
global whisper_align_model global whisper_align_model
if whisper_vad:
del whisper_vad
whisper_vad = None
if whisper_diarize:
del whisper_diarize
whisper_diarize = None
if whisper_align_model: if whisper_align_model:
del whisper_align_model del whisper_align_model
whisper_align_model = None whisper_align_model = None