forked from mrq/ai-voice-cloning
updated whisperX integration for use with the latest version (v3) (NOTE: you WILL need to also update whisperx if you pull this commit)
This commit is contained in:
parent
805d7d35e8
commit
e227ab8e08
65
src/utils.py
65
src/utils.py
|
@ -177,8 +177,6 @@ webui = None
|
|||
voicefixer = None
|
||||
|
||||
whisper_model = None
|
||||
whisper_vad = None
|
||||
whisper_diarize = None
|
||||
whisper_align_model = None
|
||||
|
||||
training_state = None
|
||||
|
@ -1999,8 +1997,6 @@ def whisper_sanitize( results ):
|
|||
def whisper_transcribe( file, language=None ):
|
||||
# shouldn't happen, but it's for safety
|
||||
global whisper_model
|
||||
global whisper_vad
|
||||
global whisper_diarize
|
||||
global whisper_align_model
|
||||
|
||||
if not whisper_model:
|
||||
|
@ -2035,37 +2031,12 @@ def whisper_transcribe( file, language=None ):
|
|||
|
||||
if args.whisper_backend == "m-bain/whisperx":
|
||||
import whisperx
|
||||
from whisperx.diarize import assign_word_speakers
|
||||
|
||||
device = "cuda" if get_device_name() == "cuda" else "cpu"
|
||||
if whisper_vad:
|
||||
# omits a considerable amount of the end
|
||||
if args.whisper_batchsize > 1:
|
||||
result = whisperx.transcribe_with_vad_parallel(whisper_model, file, whisper_vad, batch_size=args.whisper_batchsize, language=language, task="transcribe")
|
||||
else:
|
||||
result = whisperx.transcribe_with_vad(whisper_model, file, whisper_vad)
|
||||
"""
|
||||
result = whisperx.transcribe_with_vad(whisper_model, file, whisper_vad)
|
||||
"""
|
||||
else:
|
||||
result = whisper_model.transcribe(file)
|
||||
result = whisper_model.transcribe(file, batch_size=args.whisper_batchsize)
|
||||
|
||||
align_model, metadata = whisper_align_model
|
||||
result_aligned = whisperx.align(result["segments"], align_model, metadata, file, device)
|
||||
|
||||
if whisper_diarize:
|
||||
diarize_segments = whisper_diarize(file)
|
||||
diarize_df = pd.DataFrame(diarize_segments.itertracks(yield_label=True))
|
||||
diarize_df['start'] = diarize_df[0].apply(lambda x: x.start)
|
||||
diarize_df['end'] = diarize_df[0].apply(lambda x: x.end)
|
||||
# assumes each utterance is single speaker (needs fix)
|
||||
result_segments, word_segments = assign_word_speakers(diarize_df, result_aligned["segments"], fill_nearest=True)
|
||||
result_aligned["segments"] = result_segments
|
||||
result_aligned["word_segments"] = word_segments
|
||||
|
||||
for i in range(len(result_aligned['segments'])):
|
||||
del result_aligned['segments'][i]['word-segments']
|
||||
del result_aligned['segments'][i]['char-segments']
|
||||
result_aligned = whisperx.align(result["segments"], align_model, metadata, file, device, return_char_alignments=False)
|
||||
|
||||
result['segments'] = result_aligned['segments']
|
||||
result['text'] = []
|
||||
|
@ -3625,8 +3596,6 @@ def unload_voicefixer():
|
|||
|
||||
def load_whisper_model(language=None, model_name=None, progress=None):
|
||||
global whisper_model
|
||||
global whisper_vad
|
||||
global whisper_diarize
|
||||
global whisper_align_model
|
||||
|
||||
if args.whisper_backend not in WHISPER_BACKENDS:
|
||||
|
@ -3662,45 +3631,15 @@ def load_whisper_model(language=None, model_name=None, progress=None):
|
|||
elif args.whisper_backend == "m-bain/whisperx":
|
||||
import whisper, whisperx
|
||||
device = "cuda" if get_device_name() == "cuda" else "cpu"
|
||||
try:
|
||||
whisper_model = whisperx.load_model(model_name, device)
|
||||
except Exception as e:
|
||||
whisper_model = whisper.load_model(model_name, device)
|
||||
|
||||
if not args.hf_token:
|
||||
print("No huggingface token used, needs to be saved in environment variable, otherwise will throw error loading VAD model.")
|
||||
|
||||
try:
|
||||
from pyannote.audio import Inference, Pipeline
|
||||
whisper_vad = Inference(
|
||||
"pyannote/segmentation",
|
||||
pre_aggregation_hook=lambda segmentation: segmentation,
|
||||
use_auth_token=args.hf_token,
|
||||
device=torch.device(device),
|
||||
)
|
||||
# whisper_diarize = Pipeline.from_pretrained("pyannote/speaker-diarization@2.1",use_auth_token=args.hf_token)
|
||||
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
whisper_align_model = whisperx.load_align_model(model_name="WAV2VEC2_ASR_LARGE_LV60K_960H" if language=="en" else None, language_code=language, device=device)
|
||||
|
||||
print("Loaded Whisper model")
|
||||
|
||||
def unload_whisper():
|
||||
global whisper_model
|
||||
global whisper_vad
|
||||
global whisper_diarize
|
||||
global whisper_align_model
|
||||
|
||||
if whisper_vad:
|
||||
del whisper_vad
|
||||
whisper_vad = None
|
||||
|
||||
if whisper_diarize:
|
||||
del whisper_diarize
|
||||
whisper_diarize = None
|
||||
|
||||
if whisper_align_model:
|
||||
del whisper_align_model
|
||||
whisper_align_model = None
|
||||
|
|
Loading…
Reference in New Issue
Block a user