also shifted to transformer's pipeline for transcribing

This commit is contained in:
mrq 2024-12-11 19:57:53 -06:00
parent b81a98799b
commit 7e54e897f7
2 changed files with 111 additions and 58 deletions

View File

@ -134,7 +134,7 @@ def main():
parser.add_argument("--lora", action="store_true")
parser.add_argument("--comparison", type=str, default=None)
parser.add_argument("--transcription-model", type=str, default="base")
parser.add_argument("--transcription-model", type=str, default="openai/whisper-base")
parser.add_argument("--speaker-similarity-model", type=str, default="microsoft/wavlm-base-sv")
args = parser.parse_args()

View File

@ -9,7 +9,14 @@ import argparse
import torch
import torchaudio
import whisperx
try:
import whisperx
except Exception as e:
whisperx = None
print(f"Error while querying for whisperx: {str(e)}")
pass
from transformers import pipeline
from functools import cache
from tqdm.auto import tqdm
@ -17,7 +24,6 @@ from pathlib import Path
from ..utils import coerce_dtype
def pad(num, zeroes):
return str(num).zfill(zeroes+1)
@ -32,7 +38,7 @@ _cached_models = {
"align": (None, None),
}
# yes I can write a decorator to do this
def _load_model(model_name="large-v3", device="cuda", dtype="float16", language="auto"):
def _load_model(model_name="openai/whisper-large-v3", device="cuda", dtype="float16", language="auto", backend="auto", attention="sdpa"):
cache_key = f'{model_name}:{device}:{dtype}:{language}'
if _cached_models["model"][0] == cache_key:
return _cached_models["model"][1]
@ -59,28 +65,57 @@ def _load_model(model_name="large-v3", device="cuda", dtype="float16", language=
if language != "auto":
kwargs["language"] = language
model = whisperx.load_model(model_name, **kwargs)
if backend == "auto" and whisperx is not None:
backend = "whisperx"
if backend == "whisperx":
model_name = model_name.replace("openai/whisper-", "")
model = whisperx.load_model(model_name, **kwargs)
else:
model = pipeline(
"automatic-speech-recognition",
model=model_name,
torch_dtype=coerce_dtype(dtype),
device=device,
model_kwargs={"attn_implementation": attention},
)
_cached_models["model"] = (cache_key, model)
return model
def _load_diarization_model(device="cuda"):
def _load_diarization_model(device="cuda", backend="auto"):
cache_key = f'{device}'
if _cached_models["diarization"][0] == cache_key:
return _cached_models["diarization"][1]
del _cached_models["diarization"]
model = whisperx.DiarizationPipeline(device=device)
if backend == "auto" and whisperx is not None:
backend = "whisperx"
if backend == "whisperx":
model = whisperx.DiarizationPipeline(device=device)
else:
model = None # to do later
_cached_models["diarization"] = (cache_key, model)
return model
def _load_align_model(language, device="cuda"):
def _load_align_model(language, device="cuda", backend="auto"):
cache_key = f'{language}:{device}'
if _cached_models["align"][0] == cache_key:
return _cached_models["align"][1]
del _cached_models["align"]
model = whisperx.load_align_model(language_code=language, device=device)
if backend == "auto" and whisperx is not None:
backend = "whisperx"
if backend == "whisperx":
model = whisperx.load_align_model(language_code=language, device=device)
else:
model = None # to do later
_cached_models["align"] = (cache_key, model)
return model
@ -111,6 +146,70 @@ def transcribe(
"end": 0,
}
# load requested models
model_kwargs["backend"] = "automatic-speech-recognition"
device = model_kwargs.get("device", "cuda")
model = _load_model(language=language, **model_kwargs)
result = model(
str(audio),
chunk_length_s=30,
batch_size=batch_size,
generate_kwargs={"task": "transcribe", "language": None if language == "auto" else language},
return_timestamps="word" if align else False,
return_language=True,
)
start = 0
end = 0
segments = []
for segment in result["chunks"]:
text = segment["text"]
if "timestamp" in segment:
s, e = segment["timestamp"]
start = min( start, s )
end = max( end, e )
else:
s, e = None, None
if language == "auto":
language = segment["language"]
segments.append({
"start": s,
"end": e,
"text": text,
})
if language != "auto":
metadata["language"] = language
metadata["segments"] = segments
metadata["text"] = result["text"].strip()
metadata["start"] = start
metadata["end"] = end
return metadata
# for backwards compat since it also handles some other things for me
def transcribe_whisperx(
audio,
language = "auto",
diarize = False,
batch_size = 16,
verbose=False,
align=True,
**model_kwargs,
):
metadata = {
"segments": [],
"language": "",
"text": "",
"start": 0,
"end": 0,
}
# load requested models
device = model_kwargs.get("device", "cuda")
model = _load_model(language=language, **model_kwargs)
@ -154,7 +253,7 @@ def transcribe_batch(
input_audio = "voices",
input_voice = None,
output_metadata = "training/metadata",
model_name = "large-v3",
model_name = "openai/whisper-large-v3",
skip_existing = True,
diarize = False,
@ -178,12 +277,6 @@ def transcribe_batch(
if input_voice is not None:
only_speakers = [input_voice]
"""
align_model, align_model_metadata, align_model_language = (None, None, None)
model =_load_model(model_name, device, compute_type=dtype)
diarize_model = _load_diarization_model(device=device) if diarize else None
"""
for dataset_name in os.listdir(f'./{input_audio}/'):
if not os.path.isdir(f'./{input_audio}/{dataset_name}/'):
continue
@ -222,47 +315,7 @@ def transcribe_batch(
if os.path.isdir(inpath):
continue
metadata[filename] = transcribe( inpath, model_name=model_name, diarize=diarize, device=device, dtype=dtype )
"""
metadata[filename] = {
"segments": [],
"language": "",
"text": "",
"start": 0,
"end": 0,
}
audio = whisperx.load_audio(inpath)
result = model.transcribe(audio, batch_size=batch_size)
language = result["language"]
if align_model_language != language:
tqdm.write(f'Loading language: {language}')
align_model_language = language
align_model, align_model_metadata = _load_align_model(language=language, device=device)
result = whisperx.align(result["segments"], align_model, align_model_metadata, audio, device, return_char_alignments=False)
metadata[filename]["segments"] = result["segments"]
metadata[filename]["language"] = language
if diarize_model is not None:
diarize_segments = diarize_model(audio)
result = whisperx.assign_word_speakers(diarize_segments, result)
text = []
start = 0
end = 0
for segment in result["segments"]:
text.append( segment["text"] )
start = min( start, segment["start"] )
end = max( end, segment["end"] )
metadata[filename]["text"] = " ".join(text).strip()
metadata[filename]["start"] = start
metadata[filename]["end"] = end
"""
metadata[filename] = transcribe_whisperx( inpath, model_name=model_name, diarize=diarize, device=device, dtype=dtype )
open(outpath, 'w', encoding='utf-8').write(json.dumps(metadata))
@ -273,7 +326,7 @@ def main():
parser.add_argument("--input-voice", type=str, default=None)
parser.add_argument("--output-metadata", type=str, default="training/metadata")
parser.add_argument("--model-name", type=str, default="large-v3")
parser.add_argument("--model-name", type=str, default="openai/whisper-large-v3")
parser.add_argument("--skip-existing", action="store_true")
parser.add_argument("--diarize", action="store_true")
parser.add_argument("--batch-size", type=int, default=16)