From 7e54e897f71e94f33d00327fff7be1a538a72f6e Mon Sep 17 00:00:00 2001
From: mrq
Date: Wed, 11 Dec 2024 19:57:53 -0600
Subject: [PATCH] also shifted to transformers' pipeline for transcribing

---
 vall_e/demo.py           |   2 +-
 vall_e/emb/transcribe.py | 167 ++++++++++++++++++++++++++-------------
 2 files changed, 111 insertions(+), 58 deletions(-)

diff --git a/vall_e/demo.py b/vall_e/demo.py
index 9244958..e0e0323 100644
--- a/vall_e/demo.py
+++ b/vall_e/demo.py
@@ -134,7 +134,7 @@ def main():
 	parser.add_argument("--lora", action="store_true")
 	parser.add_argument("--comparison", type=str, default=None)
 
-	parser.add_argument("--transcription-model", type=str, default="base")
+	parser.add_argument("--transcription-model", type=str, default="openai/whisper-base")
 	parser.add_argument("--speaker-similarity-model", type=str, default="microsoft/wavlm-base-sv")
 
 	args = parser.parse_args()
diff --git a/vall_e/emb/transcribe.py b/vall_e/emb/transcribe.py
index 5d9ec5f..ef4d56a 100644
--- a/vall_e/emb/transcribe.py
+++ b/vall_e/emb/transcribe.py
@@ -9,7 +9,14 @@ import argparse
 
 import torch
 import torchaudio
-import whisperx
+try:
+	import whisperx
+except Exception as e:
+	whisperx = None
+	print(f"Error while importing whisperx: {str(e)}")
+	pass
+
+from transformers import pipeline
 
 from functools import cache
 from tqdm.auto import tqdm
@@ -17,7 +24,6 @@ from pathlib import Path
 
 from ..utils import coerce_dtype
 
-
 def pad(num, zeroes):
 	return str(num).zfill(zeroes+1)
 
@@ -32,7 +38,7 @@ _cached_models = {
 	"align": (None, None),
 }
 # yes I can write a decorator to do this
-def _load_model(model_name="large-v3", device="cuda", dtype="float16", language="auto"):
+def _load_model(model_name="openai/whisper-large-v3", device="cuda", dtype="float16", language="auto", backend="auto", attention="sdpa"):
 	cache_key = f'{model_name}:{device}:{dtype}:{language}'
 	if _cached_models["model"][0] == cache_key:
 		return _cached_models["model"][1]
@@ -59,28 +65,57 @@ def _load_model(model_name="large-v3", device="cuda", dtype="float16", language=
 	if language != "auto":
 		kwargs["language"] = language
 
-	model = whisperx.load_model(model_name, **kwargs)
+	if backend == "auto" and whisperx is not None:
+		backend = "whisperx"
+
+	if backend == "whisperx":
+		model_name = model_name.replace("openai/whisper-", "")
+		model = whisperx.load_model(model_name, **kwargs)
+	else:
+		model = pipeline(
+			"automatic-speech-recognition",
+			model=model_name,
+			torch_dtype=coerce_dtype(dtype),
+			device=device,
+			model_kwargs={"attn_implementation": attention},
+		)
 
 	_cached_models["model"] = (cache_key, model)
 	return model
 
-def _load_diarization_model(device="cuda"):
+def _load_diarization_model(device="cuda", backend="auto"):
 	cache_key = f'{device}'
 	if _cached_models["diarization"][0] == cache_key:
 		return _cached_models["diarization"][1]
 	del _cached_models["diarization"]
-	model = whisperx.DiarizationPipeline(device=device)
+
+	if backend == "auto" and whisperx is not None:
+		backend = "whisperx"
+
+	if backend == "whisperx":
+		model = whisperx.DiarizationPipeline(device=device)
+	else:
+		model = None # TODO: non-whisperx diarization backend
+
 	_cached_models["diarization"] = (cache_key, model)
 	return model
 
-def _load_align_model(language, device="cuda"):
+def _load_align_model(language, device="cuda", backend="auto"):
 	cache_key = f'{language}:{device}'
 	if _cached_models["align"][0] == cache_key:
 		return _cached_models["align"][1]
 	del _cached_models["align"]
-	model = whisperx.load_align_model(language_code=language, device=device)
+
+	if backend == "auto" and whisperx is not None:
+		backend = "whisperx"
+
+	if backend == "whisperx":
+		model = whisperx.load_align_model(language_code=language, device=device)
+	else:
+		model = None # TODO: non-whisperx alignment backend
+
 	_cached_models["align"] = (cache_key, model)
 	return model
 
@@ -111,6 +146,70 @@ def transcribe(
 		"end": 0,
 	}
 
+	# load requested models
+	model_kwargs["backend"] = "transformers" # anything other than "whisperx"/"auto" selects the transformers pipeline path
+	device = model_kwargs.get("device", "cuda")
+	model = _load_model(language=language, **model_kwargs)
+
+	result = model(
+		str(audio),
+		chunk_length_s=30,
+		batch_size=batch_size,
+		generate_kwargs={"task": "transcribe", "language": None if language == "auto" else language},
+		return_timestamps="word" if align else False,
+		return_language=True,
+	)
+
+	start = float("inf") # seeded high so the first timestamp wins the min()
+	end = 0
+	segments = []
+	for segment in result.get("chunks", []): # "chunks" is absent when timestamps are disabled
+		text = segment["text"]
+
+		if "timestamp" in segment:
+			s, e = segment["timestamp"]
+			start = min( start, s )
+			end = max( end, e ) if e is not None else end # the final chunk can lack an end timestamp
+		else:
+			s, e = None, None
+
+		if language == "auto":
+			language = segment["language"]
+
+		segments.append({
+			"start": s,
+			"end": e,
+			"text": text,
+		})
+
+	if language != "auto":
+		metadata["language"] = language
+
+	metadata["segments"] = segments
+	metadata["text"] = result["text"].strip()
+	metadata["start"] = 0 if start == float("inf") else start
+	metadata["end"] = end
+
+	return metadata
+
+# for backwards compat since it also handles some other things for me
+def transcribe_whisperx(
+	audio,
+	language = "auto",
+	diarize = False,
+	batch_size = 16,
+	verbose=False,
+	align=True,
+	**model_kwargs,
+):
+	metadata = {
+		"segments": [],
+		"language": "",
+		"text": "",
+		"start": 0,
+		"end": 0,
+	}
+
 	# load requested models
 	device = model_kwargs.get("device", "cuda")
 	model = _load_model(language=language, **model_kwargs)
@@ -154,7 +253,7 @@ def transcribe_batch(
 	input_audio = "voices",
 	input_voice = None,
 	output_metadata = "training/metadata",
-	model_name = "large-v3",
+	model_name = "openai/whisper-large-v3",
 
 	skip_existing = True,
 	diarize = False,
@@ -178,12 +277,6 @@ def transcribe_batch(
 	if input_voice is not None:
 		only_speakers = [input_voice]
 
-	"""
-	align_model, align_model_metadata, align_model_language = (None, None, None)
-	model =_load_model(model_name, device, compute_type=dtype)
-	diarize_model = _load_diarization_model(device=device) if diarize else None
-	"""
-
 	for dataset_name in os.listdir(f'./{input_audio}/'):
 		if not os.path.isdir(f'./{input_audio}/{dataset_name}/'):
 			continue
@@ -222,47 +315,7 @@ def transcribe_batch(
 			if os.path.isdir(inpath):
 				continue
 
-			metadata[filename] = transcribe( inpath, model_name=model_name, diarize=diarize, device=device, dtype=dtype )
-
-			"""
-			metadata[filename] = {
-				"segments": [],
-				"language": "",
-				"text": "",
-				"start": 0,
-				"end": 0,
-			}
-
-			audio = whisperx.load_audio(inpath)
-			result = model.transcribe(audio, batch_size=batch_size)
-			language = result["language"]
-
-			if align_model_language != language:
-				tqdm.write(f'Loading language: {language}')
-				align_model_language = language
-				align_model, align_model_metadata = _load_align_model(language=language, device=device)
-
-			result = whisperx.align(result["segments"], align_model, align_model_metadata, audio, device, return_char_alignments=False)
-
-			metadata[filename]["segments"] = result["segments"]
-			metadata[filename]["language"] = language
-
-			if diarize_model is not None:
-				diarize_segments = diarize_model(audio)
-				result = whisperx.assign_word_speakers(diarize_segments, result)
-
-			text = []
-			start = 0
-			end = 0
-			for segment in result["segments"]:
-				text.append( segment["text"] )
-				start = min( start, segment["start"] )
-				end = max( end, segment["end"] )
-
-			metadata[filename]["text"] = " ".join(text).strip()
-			metadata[filename]["start"] = start
-			metadata[filename]["end"] = end
-			"""
+			metadata[filename] = transcribe_whisperx( inpath, model_name=model_name, diarize=diarize, device=device, dtype=dtype )
 
 			open(outpath, 'w', encoding='utf-8').write(json.dumps(metadata))
 
@@ -273,7 +326,7 @@ def main():
 	parser.add_argument("--input-voice", type=str, default=None)
 	parser.add_argument("--output-metadata", type=str, default="training/metadata")
 
-	parser.add_argument("--model-name", type=str, default="large-v3")
+	parser.add_argument("--model-name", type=str, default="openai/whisper-large-v3")
 	parser.add_argument("--skip-existing", action="store_true")
 	parser.add_argument("--diarize", action="store_true")
 	parser.add_argument("--batch-size", type=int, default=16)
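
For reference, a minimal sketch of what the new non-whisperx path in _load_model()/transcribe() boils down to, assuming only that torch and transformers are installed; the checkpoint id and audio filename below are placeholders, not values taken from the patch:

	import torch
	from transformers import pipeline

	# build the same ASR pipeline the patch constructs when whisperx is unavailable
	asr = pipeline(
		"automatic-speech-recognition",
		model="openai/whisper-base",
		torch_dtype=torch.float16,
		device="cuda",
		model_kwargs={"attn_implementation": "sdpa"},
	)

	# long-form transcription with word-level timestamps, mirroring transcribe()
	result = asr(
		"utterance.wav", # hypothetical input file
		chunk_length_s=30,
		batch_size=16,
		generate_kwargs={"task": "transcribe", "language": None}, # None lets Whisper detect the language
		return_timestamps="word",
		return_language=True,
	)

	print(result["text"].strip())
	for chunk in result.get("chunks", []): # "chunks" only exists when timestamps are requested
		print(chunk["timestamp"], chunk["text"])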
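Backend selection stays implicit rather than configured: the import-time try/except decides whether whisperx exists, _load_model() only promotes backend="auto" to "whisperx" when it does, and any other backend value falls through to the transformers pipeline; the new transcribe() pins that path, while transcribe_whisperx() and transcribe_batch() keep the old whisperx behavior. A hedged usage sketch against the patched module, with the audio path as a placeholder:

	from vall_e.emb.transcribe import transcribe, transcribe_whisperx

	# transformers pipeline path; transcribe() forces the backend itself
	meta = transcribe("utterance.wav", model_name="openai/whisper-base", device="cuda", dtype="float16")

	# whisperx path, kept for backwards compat (alignment/diarization);
	# needs whisperx to import cleanly, and the "openai/whisper-" prefix is stripped internally
	meta = transcribe_whisperx("utterance.wav", model_name="openai/whisper-large-v3", diarize=True, device="cuda", dtype="float16")

	print(meta["text"], meta["start"], meta["end"])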