From 4056a27bcbb2aaf01daf4b0552353d0adb672302 Mon Sep 17 00:00:00 2001 From: mrq Date: Wed, 22 Mar 2023 19:24:53 +0000 Subject: [PATCH] begrudgingly added back whisperx integration (VAD/Diarization testing, I really, really need accurate timestamps before dumping mondo amounts of time on training a dataset) --- src/main.py | 2 ++ src/utils.py | 75 ++++++++++++++++++++++++++++++++++++++++++++++++---- 2 files changed, 72 insertions(+), 5 deletions(-) diff --git a/src/main.py b/src/main.py index d2d3c27..cfde26b 100755 --- a/src/main.py +++ b/src/main.py @@ -6,6 +6,8 @@ if 'TORTOISE_MODELS_DIR' not in os.environ: if 'TRANSFORMERS_CACHE' not in os.environ: os.environ['TRANSFORMERS_CACHE'] = os.path.realpath(os.path.join(os.getcwd(), './models/transformers/')) +os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python' + from utils import * from webui import * diff --git a/src/utils.py b/src/utils.py index 4bf1ea5..76c8345 100755 --- a/src/utils.py +++ b/src/utils.py @@ -47,7 +47,7 @@ MODELS['dvae.pth'] = "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/370 WHISPER_MODELS = ["tiny", "base", "small", "medium", "large"] WHISPER_SPECIALIZED_MODELS = ["tiny.en", "base.en", "small.en", "medium.en"] -WHISPER_BACKENDS = ["openai/whisper", "lightmare/whispercpp"] +WHISPER_BACKENDS = ["openai/whisper", "lightmare/whispercpp", "m-bain/whisperx"] VOCODERS = ['univnet', 'bigvgan_base_24khz_100band', 'bigvgan_24khz_100band'] TTSES = ['tortoise'] @@ -81,6 +81,8 @@ tts_loading = False webui = None voicefixer = None whisper_model = None +whisper_vad = None +whisper_diarize = None training_state = None current_voice = None @@ -1131,6 +1133,9 @@ def convert_to_halfp(): def whisper_transcribe( file, language=None ): # shouldn't happen, but it's for safety + global whisper_model + global whisper_vad + global whisper_diarize if not whisper_model: load_whisper_model(language=language) @@ -1156,6 +1161,40 @@ def whisper_transcribe( file, language=None ): result['segments'].append(reparsed) return result + if args.whisper_backend == "m-bain/whisperx": + import whisperx + from whisperx.diarize import assign_word_speakers + + device = "cuda" if get_device_name() == "cuda" else "cpu" + if whisper_vad: + if args.whisper_batchsize > 1: + result = whisperx.transcribe_with_vad_parallel(whisper_model, file, whisper_vad, batch_size=args.whisper_batchsize) + else: + result = whisperx.transcribe_with_vad(whisper_model, file, whisper_vad) + else: + result = whisper_model.transcribe(file) + + align_model, metadata = whisperx.load_align_model(language_code=result["language"], device=device) + result_aligned = whisperx.align(result["segments"], align_model, metadata, file, device) + + if whisper_diarize: + diarize_segments = whisper_diarize(file) + diarize_df = pd.DataFrame(diarize_segments.itertracks(yield_label=True)) + diarize_df['start'] = diarize_df[0].apply(lambda x: x.start) + diarize_df['end'] = diarize_df[0].apply(lambda x: x.end) + # assumes each utterance is single speaker (needs fix) + result_segments, word_segments = assign_word_speakers(diarize_df, result_aligned["segments"], fill_nearest=True) + result_aligned["segments"] = result_segments + result_aligned["word_segments"] = word_segments + + for i in range(len(result_aligned['segments'])): + del result_aligned['segments'][i]['word-segments'] + del result_aligned['segments'][i]['char-segments'] + + result['segments'] = result_aligned['segments'] + + return result + def validate_waveform( waveform, sample_rate, min_only=False ): if not torch.any(waveform < 0): return "Waveform is empty" @@ -2001,6 +2040,7 @@ def setup_args(): 'latents-lean-and-mean': True, 'voice-fixer': False, # getting tired of long initialization times in a Colab for downloading a large dataset for it 'voice-fixer-use-cuda': True, + 'force-cpu-for-conditioning-latents': False, 'defer-tts-load': False, @@ -2013,6 +2053,7 @@ def setup_args(): 'output-volume': 1, 'results-folder': "./results/", + 'hf-token': None, 'tts-backend': TTSES[0], 'autoregressive-model': None, @@ -2024,6 +2065,7 @@ def setup_args(): 'whisper-backend': 'openai/whisper', 'whisper-model': "base", + 'whisper-batchsize': 1, 'training-default-halfp': False, 'training-default-bnb': True, @@ -2061,6 +2103,7 @@ def setup_args(): parser.add_argument("--output-volume", type=float, default=default_arguments['output-volume'], help="Adjusts volume of output") parser.add_argument("--results-folder", type=str, default=default_arguments['results-folder'], help="Sets output directory") + parser.add_argument("--hf-token", type=str, default=default_arguments['hf-token'], help="HuggingFace Token") parser.add_argument("--tts-backend", default=default_arguments['tts-backend'], help="Specifies which TTS backend to use.") parser.add_argument("--autoregressive-model", default=default_arguments['autoregressive-model'], help="Specifies which autoregressive model to use for sampling.") @@ -2072,6 +2115,7 @@ def setup_args(): parser.add_argument("--whisper-backend", default=default_arguments['whisper-backend'], action='store_true', help="Picks which whisper backend to use (openai/whisper, lightmare/whispercpp)") parser.add_argument("--whisper-model", default=default_arguments['whisper-model'], help="Specifies which whisper model to use for transcription.") + parser.add_argument("--whisper-batchsize", type=int, default=default_arguments['whisper-batchsize'], help="Specifies batch size for WhisperX") parser.add_argument("--training-default-halfp", action='store_true', default=default_arguments['training-default-halfp'], help="Training default: halfp") parser.add_argument("--training-default-bnb", action='store_true', default=default_arguments['training-default-bnb'], help="Training default: bnb") @@ -2130,6 +2174,7 @@ def get_default_settings( hypenated=True ): 'output-volume': args.output_volume, 'results-folder': args.results_folder, + 'hf-token': args.hf_token, 'tts-backend': args.tts_backend, 'autoregressive-model': args.autoregressive_model, @@ -2141,6 +2186,7 @@ def get_default_settings( hypenated=True ): 'whisper-backend': args.whisper_backend, 'whisper-model': args.whisper_model, + 'whisper-batchsize': args.whisper_batchsize, 'training-default-halfp': args.training_default_halfp, 'training-default-bnb': args.training_default_bnb, @@ -2178,6 +2224,7 @@ def update_args( **kwargs ): args.output_volume = settings['output_volume'] args.results_folder = settings['results_folder'] + args.hf_token = settings['hf_token'] args.tts_backend = settings['tts_backend'] args.autoregressive_model = settings['autoregressive_model'] @@ -2189,6 +2236,7 @@ def update_args( **kwargs ): args.whisper_backend = settings['whisper_backend'] args.whisper_model = settings['whisper_model'] + args.whisper_batchsize = settings['whisper_batchsize'] args.training_default_halfp = settings['training_default_halfp'] args.training_default_bnb = settings['training_default_bnb'] @@ -2529,10 +2577,8 @@ def unload_voicefixer(): def load_whisper_model(language=None, model_name=None, progress=None): global whisper_model - - if model_name == "m-bain/whisperx": - print("WhisperX has been removed. Reverting to openai/whisper. Apologies for the inconvenience.") - model_name = "openai/whisper" + global whisper_vad + global whisper_diarize if args.whisper_backend not in WHISPER_BACKENDS: raise Exception(f"unavailable backend: {args.whisper_backend}") @@ -2564,6 +2610,25 @@ def load_whisper_model(language=None, model_name=None, progress=None): b_lang = language.encode('ascii') whisper_model = Whisper(model_name, models_dir='./models/', language=b_lang) + elif args.whisper_backend == "m-bain/whisperx": + import whisperx + device = "cuda" if get_device_name() == "cuda" else "cpu" + whisper_model = whisperx.load_model(model_name, device) + + if not args.hf_token: + print("No huggingface token used, needs to be saved in environment variable, otherwise will throw error loading VAD model.") + + try: + from pyannote.audio import Inference, Pipeline + whisper_vad = Inference( + "pyannote/segmentation", + pre_aggregation_hook=lambda segmentation: segmentation, + use_auth_token=args.hf_token, + device=torch.device(device), + ) + whisper_diarize = Pipeline.from_pretrained("pyannote/speaker-diarization@2.1",use_auth_token=args.hf_token) + except Exception as e: + pass print("Loaded Whisper model")