forked from mrq/ai-voice-cloning
updated whisperX integration for use with the latest version (v3) (NOTE: you WILL need to also update whisperx if you pull this commit)
This commit is contained in:
parent
805d7d35e8
commit
e227ab8e08
65
src/utils.py
65
src/utils.py
|
@ -177,8 +177,6 @@ webui = None
|
||||||
voicefixer = None
|
voicefixer = None
|
||||||
|
|
||||||
whisper_model = None
|
whisper_model = None
|
||||||
whisper_vad = None
|
|
||||||
whisper_diarize = None
|
|
||||||
whisper_align_model = None
|
whisper_align_model = None
|
||||||
|
|
||||||
training_state = None
|
training_state = None
|
||||||
|
@ -1999,8 +1997,6 @@ def whisper_sanitize( results ):
|
||||||
def whisper_transcribe( file, language=None ):
|
def whisper_transcribe( file, language=None ):
|
||||||
# shouldn't happen, but it's for safety
|
# shouldn't happen, but it's for safety
|
||||||
global whisper_model
|
global whisper_model
|
||||||
global whisper_vad
|
|
||||||
global whisper_diarize
|
|
||||||
global whisper_align_model
|
global whisper_align_model
|
||||||
|
|
||||||
if not whisper_model:
|
if not whisper_model:
|
||||||
|
@ -2035,37 +2031,12 @@ def whisper_transcribe( file, language=None ):
|
||||||
|
|
||||||
if args.whisper_backend == "m-bain/whisperx":
|
if args.whisper_backend == "m-bain/whisperx":
|
||||||
import whisperx
|
import whisperx
|
||||||
from whisperx.diarize import assign_word_speakers
|
|
||||||
|
|
||||||
device = "cuda" if get_device_name() == "cuda" else "cpu"
|
device = "cuda" if get_device_name() == "cuda" else "cpu"
|
||||||
if whisper_vad:
|
result = whisper_model.transcribe(file, batch_size=args.whisper_batchsize)
|
||||||
# omits a considerable amount of the end
|
|
||||||
if args.whisper_batchsize > 1:
|
|
||||||
result = whisperx.transcribe_with_vad_parallel(whisper_model, file, whisper_vad, batch_size=args.whisper_batchsize, language=language, task="transcribe")
|
|
||||||
else:
|
|
||||||
result = whisperx.transcribe_with_vad(whisper_model, file, whisper_vad)
|
|
||||||
"""
|
|
||||||
result = whisperx.transcribe_with_vad(whisper_model, file, whisper_vad)
|
|
||||||
"""
|
|
||||||
else:
|
|
||||||
result = whisper_model.transcribe(file)
|
|
||||||
|
|
||||||
align_model, metadata = whisper_align_model
|
align_model, metadata = whisper_align_model
|
||||||
result_aligned = whisperx.align(result["segments"], align_model, metadata, file, device)
|
result_aligned = whisperx.align(result["segments"], align_model, metadata, file, device, return_char_alignments=False)
|
||||||
|
|
||||||
if whisper_diarize:
|
|
||||||
diarize_segments = whisper_diarize(file)
|
|
||||||
diarize_df = pd.DataFrame(diarize_segments.itertracks(yield_label=True))
|
|
||||||
diarize_df['start'] = diarize_df[0].apply(lambda x: x.start)
|
|
||||||
diarize_df['end'] = diarize_df[0].apply(lambda x: x.end)
|
|
||||||
# assumes each utterance is single speaker (needs fix)
|
|
||||||
result_segments, word_segments = assign_word_speakers(diarize_df, result_aligned["segments"], fill_nearest=True)
|
|
||||||
result_aligned["segments"] = result_segments
|
|
||||||
result_aligned["word_segments"] = word_segments
|
|
||||||
|
|
||||||
for i in range(len(result_aligned['segments'])):
|
|
||||||
del result_aligned['segments'][i]['word-segments']
|
|
||||||
del result_aligned['segments'][i]['char-segments']
|
|
||||||
|
|
||||||
result['segments'] = result_aligned['segments']
|
result['segments'] = result_aligned['segments']
|
||||||
result['text'] = []
|
result['text'] = []
|
||||||
|
@ -3625,8 +3596,6 @@ def unload_voicefixer():
|
||||||
|
|
||||||
def load_whisper_model(language=None, model_name=None, progress=None):
|
def load_whisper_model(language=None, model_name=None, progress=None):
|
||||||
global whisper_model
|
global whisper_model
|
||||||
global whisper_vad
|
|
||||||
global whisper_diarize
|
|
||||||
global whisper_align_model
|
global whisper_align_model
|
||||||
|
|
||||||
if args.whisper_backend not in WHISPER_BACKENDS:
|
if args.whisper_backend not in WHISPER_BACKENDS:
|
||||||
|
@ -3662,45 +3631,15 @@ def load_whisper_model(language=None, model_name=None, progress=None):
|
||||||
elif args.whisper_backend == "m-bain/whisperx":
|
elif args.whisper_backend == "m-bain/whisperx":
|
||||||
import whisper, whisperx
|
import whisper, whisperx
|
||||||
device = "cuda" if get_device_name() == "cuda" else "cpu"
|
device = "cuda" if get_device_name() == "cuda" else "cpu"
|
||||||
try:
|
|
||||||
whisper_model = whisperx.load_model(model_name, device)
|
whisper_model = whisperx.load_model(model_name, device)
|
||||||
except Exception as e:
|
|
||||||
whisper_model = whisper.load_model(model_name, device)
|
|
||||||
|
|
||||||
if not args.hf_token:
|
|
||||||
print("No huggingface token used, needs to be saved in environment variable, otherwise will throw error loading VAD model.")
|
|
||||||
|
|
||||||
try:
|
|
||||||
from pyannote.audio import Inference, Pipeline
|
|
||||||
whisper_vad = Inference(
|
|
||||||
"pyannote/segmentation",
|
|
||||||
pre_aggregation_hook=lambda segmentation: segmentation,
|
|
||||||
use_auth_token=args.hf_token,
|
|
||||||
device=torch.device(device),
|
|
||||||
)
|
|
||||||
# whisper_diarize = Pipeline.from_pretrained("pyannote/speaker-diarization@2.1",use_auth_token=args.hf_token)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
pass
|
|
||||||
|
|
||||||
whisper_align_model = whisperx.load_align_model(model_name="WAV2VEC2_ASR_LARGE_LV60K_960H" if language=="en" else None, language_code=language, device=device)
|
whisper_align_model = whisperx.load_align_model(model_name="WAV2VEC2_ASR_LARGE_LV60K_960H" if language=="en" else None, language_code=language, device=device)
|
||||||
|
|
||||||
print("Loaded Whisper model")
|
print("Loaded Whisper model")
|
||||||
|
|
||||||
def unload_whisper():
|
def unload_whisper():
|
||||||
global whisper_model
|
global whisper_model
|
||||||
global whisper_vad
|
|
||||||
global whisper_diarize
|
|
||||||
global whisper_align_model
|
global whisper_align_model
|
||||||
|
|
||||||
if whisper_vad:
|
|
||||||
del whisper_vad
|
|
||||||
whisper_vad = None
|
|
||||||
|
|
||||||
if whisper_diarize:
|
|
||||||
del whisper_diarize
|
|
||||||
whisper_diarize = None
|
|
||||||
|
|
||||||
if whisper_align_model:
|
if whisper_align_model:
|
||||||
del whisper_align_model
|
del whisper_align_model
|
||||||
whisper_align_model = None
|
whisper_align_model = None
|
||||||
|
|
Loading…
Reference in New Issue
Block a user