begrudgingly added back whisperx integration (VAD/Diarization testing, I really, really need accurate timestamps before dumping mondo amounts of time on training a dataset)

This commit is contained in:
mrq 2023-03-22 19:24:53 +00:00
parent b8c3c4cfe2
commit 4056a27bcb
2 changed files with 72 additions and 5 deletions

View File

@ -6,6 +6,8 @@ if 'TORTOISE_MODELS_DIR' not in os.environ:
if 'TRANSFORMERS_CACHE' not in os.environ:
os.environ['TRANSFORMERS_CACHE'] = os.path.realpath(os.path.join(os.getcwd(), './models/transformers/'))
os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python'
from utils import *
from webui import *

View File

@ -47,7 +47,7 @@ MODELS['dvae.pth'] = "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/370
WHISPER_MODELS = ["tiny", "base", "small", "medium", "large"]
WHISPER_SPECIALIZED_MODELS = ["tiny.en", "base.en", "small.en", "medium.en"]
WHISPER_BACKENDS = ["openai/whisper", "lightmare/whispercpp"]
WHISPER_BACKENDS = ["openai/whisper", "lightmare/whispercpp", "m-bain/whisperx"]
VOCODERS = ['univnet', 'bigvgan_base_24khz_100band', 'bigvgan_24khz_100band']
TTSES = ['tortoise']
@ -81,6 +81,8 @@ tts_loading = False
webui = None
voicefixer = None
whisper_model = None
whisper_vad = None
whisper_diarize = None
training_state = None
current_voice = None
@ -1131,6 +1133,9 @@ def convert_to_halfp():
def whisper_transcribe( file, language=None ):
# shouldn't happen, but it's for safety
global whisper_model
global whisper_vad
global whisper_diarize
if not whisper_model:
load_whisper_model(language=language)
@ -1156,6 +1161,40 @@ def whisper_transcribe( file, language=None ):
result['segments'].append(reparsed)
return result
if args.whisper_backend == "m-bain/whisperx":
import whisperx
from whisperx.diarize import assign_word_speakers
device = "cuda" if get_device_name() == "cuda" else "cpu"
if whisper_vad:
if args.whisper_batchsize > 1:
result = whisperx.transcribe_with_vad_parallel(whisper_model, file, whisper_vad, batch_size=args.whisper_batchsize)
else:
result = whisperx.transcribe_with_vad(whisper_model, file, whisper_vad)
else:
result = whisper_model.transcribe(file)
align_model, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
result_aligned = whisperx.align(result["segments"], align_model, metadata, file, device)
if whisper_diarize:
diarize_segments = whisper_diarize(file)
diarize_df = pd.DataFrame(diarize_segments.itertracks(yield_label=True))
diarize_df['start'] = diarize_df[0].apply(lambda x: x.start)
diarize_df['end'] = diarize_df[0].apply(lambda x: x.end)
# assumes each utterance is single speaker (needs fix)
result_segments, word_segments = assign_word_speakers(diarize_df, result_aligned["segments"], fill_nearest=True)
result_aligned["segments"] = result_segments
result_aligned["word_segments"] = word_segments
for i in range(len(result_aligned['segments'])):
del result_aligned['segments'][i]['word-segments']
del result_aligned['segments'][i]['char-segments']
result['segments'] = result_aligned['segments']
return result
def validate_waveform( waveform, sample_rate, min_only=False ):
if not torch.any(waveform < 0):
return "Waveform is empty"
@ -2002,6 +2041,7 @@ def setup_args():
'voice-fixer': False, # getting tired of long initialization times in a Colab for downloading a large dataset for it
'voice-fixer-use-cuda': True,
'force-cpu-for-conditioning-latents': False,
'defer-tts-load': False,
'device-override': None,
@ -2013,6 +2053,7 @@ def setup_args():
'output-volume': 1,
'results-folder': "./results/",
'hf-token': None,
'tts-backend': TTSES[0],
'autoregressive-model': None,
@ -2024,6 +2065,7 @@ def setup_args():
'whisper-backend': 'openai/whisper',
'whisper-model': "base",
'whisper-batchsize': 1,
'training-default-halfp': False,
'training-default-bnb': True,
@ -2061,6 +2103,7 @@ def setup_args():
parser.add_argument("--output-volume", type=float, default=default_arguments['output-volume'], help="Adjusts volume of output")
parser.add_argument("--results-folder", type=str, default=default_arguments['results-folder'], help="Sets output directory")
parser.add_argument("--hf-token", type=str, default=default_arguments['hf-token'], help="HuggingFace Token")
parser.add_argument("--tts-backend", default=default_arguments['tts-backend'], help="Specifies which TTS backend to use.")
parser.add_argument("--autoregressive-model", default=default_arguments['autoregressive-model'], help="Specifies which autoregressive model to use for sampling.")
@ -2072,6 +2115,7 @@ def setup_args():
parser.add_argument("--whisper-backend", default=default_arguments['whisper-backend'], action='store_true', help="Picks which whisper backend to use (openai/whisper, lightmare/whispercpp)")
parser.add_argument("--whisper-model", default=default_arguments['whisper-model'], help="Specifies which whisper model to use for transcription.")
parser.add_argument("--whisper-batchsize", type=int, default=default_arguments['whisper-batchsize'], help="Specifies batch size for WhisperX")
parser.add_argument("--training-default-halfp", action='store_true', default=default_arguments['training-default-halfp'], help="Training default: halfp")
parser.add_argument("--training-default-bnb", action='store_true', default=default_arguments['training-default-bnb'], help="Training default: bnb")
@ -2130,6 +2174,7 @@ def get_default_settings( hypenated=True ):
'output-volume': args.output_volume,
'results-folder': args.results_folder,
'hf-token': args.hf_token,
'tts-backend': args.tts_backend,
'autoregressive-model': args.autoregressive_model,
@ -2141,6 +2186,7 @@ def get_default_settings( hypenated=True ):
'whisper-backend': args.whisper_backend,
'whisper-model': args.whisper_model,
'whisper-batchsize': args.whisper_batchsize,
'training-default-halfp': args.training_default_halfp,
'training-default-bnb': args.training_default_bnb,
@ -2178,6 +2224,7 @@ def update_args( **kwargs ):
args.output_volume = settings['output_volume']
args.results_folder = settings['results_folder']
args.hf_token = settings['hf_token']
args.tts_backend = settings['tts_backend']
args.autoregressive_model = settings['autoregressive_model']
@ -2189,6 +2236,7 @@ def update_args( **kwargs ):
args.whisper_backend = settings['whisper_backend']
args.whisper_model = settings['whisper_model']
args.whisper_batchsize = settings['whisper_batchsize']
args.training_default_halfp = settings['training_default_halfp']
args.training_default_bnb = settings['training_default_bnb']
@ -2529,10 +2577,8 @@ def unload_voicefixer():
def load_whisper_model(language=None, model_name=None, progress=None):
global whisper_model
if model_name == "m-bain/whisperx":
print("WhisperX has been removed. Reverting to openai/whisper. Apologies for the inconvenience.")
model_name = "openai/whisper"
global whisper_vad
global whisper_diarize
if args.whisper_backend not in WHISPER_BACKENDS:
raise Exception(f"unavailable backend: {args.whisper_backend}")
@ -2564,6 +2610,25 @@ def load_whisper_model(language=None, model_name=None, progress=None):
b_lang = language.encode('ascii')
whisper_model = Whisper(model_name, models_dir='./models/', language=b_lang)
elif args.whisper_backend == "m-bain/whisperx":
import whisperx
device = "cuda" if get_device_name() == "cuda" else "cpu"
whisper_model = whisperx.load_model(model_name, device)
if not args.hf_token:
print("No huggingface token used, needs to be saved in environment variable, otherwise will throw error loading VAD model.")
try:
from pyannote.audio import Inference, Pipeline
whisper_vad = Inference(
"pyannote/segmentation",
pre_aggregation_hook=lambda segmentation: segmentation,
use_auth_token=args.hf_token,
device=torch.device(device),
)
whisper_diarize = Pipeline.from_pretrained("pyannote/speaker-diarization@2.1",use_auth_token=args.hf_token)
except Exception as e:
pass
print("Loaded Whisper model")