From e227ab8e08f6cc17a466f7bc0530773263acf04b Mon Sep 17 00:00:00 2001 From: mrq Date: Fri, 9 Jun 2023 02:41:29 +0000 Subject: [PATCH] updated whisperX integration for use with the latest version (v3) (NOTE: you WILL need to also update whisperx if you pull this commit) --- src/utils.py | 67 +++------------------------------------------------- 1 file changed, 3 insertions(+), 64 deletions(-) diff --git a/src/utils.py b/src/utils.py index 9c2c621..a204011 100755 --- a/src/utils.py +++ b/src/utils.py @@ -177,8 +177,6 @@ webui = None voicefixer = None whisper_model = None -whisper_vad = None -whisper_diarize = None whisper_align_model = None training_state = None @@ -1999,8 +1997,6 @@ def whisper_sanitize( results ): def whisper_transcribe( file, language=None ): # shouldn't happen, but it's for safety global whisper_model - global whisper_vad - global whisper_diarize global whisper_align_model if not whisper_model: @@ -2035,37 +2031,12 @@ def whisper_transcribe( file, language=None ): if args.whisper_backend == "m-bain/whisperx": import whisperx - from whisperx.diarize import assign_word_speakers device = "cuda" if get_device_name() == "cuda" else "cpu" - if whisper_vad: - # omits a considerable amount of the end - if args.whisper_batchsize > 1: - result = whisperx.transcribe_with_vad_parallel(whisper_model, file, whisper_vad, batch_size=args.whisper_batchsize, language=language, task="transcribe") - else: - result = whisperx.transcribe_with_vad(whisper_model, file, whisper_vad) - """ - result = whisperx.transcribe_with_vad(whisper_model, file, whisper_vad) - """ - else: - result = whisper_model.transcribe(file) + result = whisper_model.transcribe(file, batch_size=args.whisper_batchsize) align_model, metadata = whisper_align_model - result_aligned = whisperx.align(result["segments"], align_model, metadata, file, device) - - if whisper_diarize: - diarize_segments = whisper_diarize(file) - diarize_df = pd.DataFrame(diarize_segments.itertracks(yield_label=True)) - diarize_df['start'] = diarize_df[0].apply(lambda x: x.start) - diarize_df['end'] = diarize_df[0].apply(lambda x: x.end) - # assumes each utterance is single speaker (needs fix) - result_segments, word_segments = assign_word_speakers(diarize_df, result_aligned["segments"], fill_nearest=True) - result_aligned["segments"] = result_segments - result_aligned["word_segments"] = word_segments - - for i in range(len(result_aligned['segments'])): - del result_aligned['segments'][i]['word-segments'] - del result_aligned['segments'][i]['char-segments'] + result_aligned = whisperx.align(result["segments"], align_model, metadata, file, device, return_char_alignments=False) result['segments'] = result_aligned['segments'] result['text'] = [] @@ -3625,8 +3596,6 @@ def unload_voicefixer(): def load_whisper_model(language=None, model_name=None, progress=None): global whisper_model - global whisper_vad - global whisper_diarize global whisper_align_model if args.whisper_backend not in WHISPER_BACKENDS: @@ -3662,45 +3631,15 @@ def load_whisper_model(language=None, model_name=None, progress=None): elif args.whisper_backend == "m-bain/whisperx": import whisper, whisperx device = "cuda" if get_device_name() == "cuda" else "cpu" - try: - whisper_model = whisperx.load_model(model_name, device) - except Exception as e: - whisper_model = whisper.load_model(model_name, device) - - if not args.hf_token: - print("No huggingface token used, needs to be saved in environment variable, otherwise will throw error loading VAD model.") - - try: - from pyannote.audio import Inference, Pipeline - whisper_vad = Inference( - "pyannote/segmentation", - pre_aggregation_hook=lambda segmentation: segmentation, - use_auth_token=args.hf_token, - device=torch.device(device), - ) - # whisper_diarize = Pipeline.from_pretrained("pyannote/speaker-diarization@2.1",use_auth_token=args.hf_token) - - except Exception as e: - pass - + whisper_model = whisperx.load_model(model_name, device) whisper_align_model = whisperx.load_align_model(model_name="WAV2VEC2_ASR_LARGE_LV60K_960H" if language=="en" else None, language_code=language, device=device) print("Loaded Whisper model") def unload_whisper(): global whisper_model - global whisper_vad - global whisper_diarize global whisper_align_model - if whisper_vad: - del whisper_vad - whisper_vad = None - - if whisper_diarize: - del whisper_diarize - whisper_diarize = None - if whisper_align_model: del whisper_align_model whisper_align_model = None