forked from mrq/ai-voice-cloning
begrudgingly added back whisperx integration (VAD/Diarization testing, I really, really need accurate timestamps before dumping mondo amounts of time on training a dataset)
This commit is contained in:
parent
b8c3c4cfe2
commit
4056a27bcb
|
@ -6,6 +6,8 @@ if 'TORTOISE_MODELS_DIR' not in os.environ:
|
|||
if 'TRANSFORMERS_CACHE' not in os.environ:
|
||||
os.environ['TRANSFORMERS_CACHE'] = os.path.realpath(os.path.join(os.getcwd(), './models/transformers/'))
|
||||
|
||||
os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python'
|
||||
|
||||
from utils import *
|
||||
from webui import *
|
||||
|
||||
|
|
75
src/utils.py
75
src/utils.py
|
@ -47,7 +47,7 @@ MODELS['dvae.pth'] = "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/370
|
|||
|
||||
WHISPER_MODELS = ["tiny", "base", "small", "medium", "large"]
|
||||
WHISPER_SPECIALIZED_MODELS = ["tiny.en", "base.en", "small.en", "medium.en"]
|
||||
WHISPER_BACKENDS = ["openai/whisper", "lightmare/whispercpp"]
|
||||
WHISPER_BACKENDS = ["openai/whisper", "lightmare/whispercpp", "m-bain/whisperx"]
|
||||
VOCODERS = ['univnet', 'bigvgan_base_24khz_100band', 'bigvgan_24khz_100band']
|
||||
TTSES = ['tortoise']
|
||||
|
||||
|
@ -81,6 +81,8 @@ tts_loading = False
|
|||
webui = None
|
||||
voicefixer = None
|
||||
whisper_model = None
|
||||
whisper_vad = None
|
||||
whisper_diarize = None
|
||||
training_state = None
|
||||
|
||||
current_voice = None
|
||||
|
@ -1131,6 +1133,9 @@ def convert_to_halfp():
|
|||
|
||||
def whisper_transcribe( file, language=None ):
|
||||
# shouldn't happen, but it's for safety
|
||||
global whisper_model
|
||||
global whisper_vad
|
||||
global whisper_diarize
|
||||
if not whisper_model:
|
||||
load_whisper_model(language=language)
|
||||
|
||||
|
@ -1156,6 +1161,40 @@ def whisper_transcribe( file, language=None ):
|
|||
result['segments'].append(reparsed)
|
||||
return result
|
||||
|
||||
if args.whisper_backend == "m-bain/whisperx":
|
||||
import whisperx
|
||||
from whisperx.diarize import assign_word_speakers
|
||||
|
||||
device = "cuda" if get_device_name() == "cuda" else "cpu"
|
||||
if whisper_vad:
|
||||
if args.whisper_batchsize > 1:
|
||||
result = whisperx.transcribe_with_vad_parallel(whisper_model, file, whisper_vad, batch_size=args.whisper_batchsize)
|
||||
else:
|
||||
result = whisperx.transcribe_with_vad(whisper_model, file, whisper_vad)
|
||||
else:
|
||||
result = whisper_model.transcribe(file)
|
||||
|
||||
align_model, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
|
||||
result_aligned = whisperx.align(result["segments"], align_model, metadata, file, device)
|
||||
|
||||
if whisper_diarize:
|
||||
diarize_segments = whisper_diarize(file)
|
||||
diarize_df = pd.DataFrame(diarize_segments.itertracks(yield_label=True))
|
||||
diarize_df['start'] = diarize_df[0].apply(lambda x: x.start)
|
||||
diarize_df['end'] = diarize_df[0].apply(lambda x: x.end)
|
||||
# assumes each utterance is single speaker (needs fix)
|
||||
result_segments, word_segments = assign_word_speakers(diarize_df, result_aligned["segments"], fill_nearest=True)
|
||||
result_aligned["segments"] = result_segments
|
||||
result_aligned["word_segments"] = word_segments
|
||||
|
||||
for i in range(len(result_aligned['segments'])):
|
||||
del result_aligned['segments'][i]['word-segments']
|
||||
del result_aligned['segments'][i]['char-segments']
|
||||
|
||||
result['segments'] = result_aligned['segments']
|
||||
|
||||
return result
|
||||
|
||||
def validate_waveform( waveform, sample_rate, min_only=False ):
|
||||
if not torch.any(waveform < 0):
|
||||
return "Waveform is empty"
|
||||
|
@ -2001,6 +2040,7 @@ def setup_args():
|
|||
'latents-lean-and-mean': True,
|
||||
'voice-fixer': False, # getting tired of long initialization times in a Colab for downloading a large dataset for it
|
||||
'voice-fixer-use-cuda': True,
|
||||
|
||||
|
||||
'force-cpu-for-conditioning-latents': False,
|
||||
'defer-tts-load': False,
|
||||
|
@ -2013,6 +2053,7 @@ def setup_args():
|
|||
'output-volume': 1,
|
||||
'results-folder': "./results/",
|
||||
|
||||
'hf-token': None,
|
||||
'tts-backend': TTSES[0],
|
||||
|
||||
'autoregressive-model': None,
|
||||
|
@ -2024,6 +2065,7 @@ def setup_args():
|
|||
|
||||
'whisper-backend': 'openai/whisper',
|
||||
'whisper-model': "base",
|
||||
'whisper-batchsize': 1,
|
||||
|
||||
'training-default-halfp': False,
|
||||
'training-default-bnb': True,
|
||||
|
@ -2061,6 +2103,7 @@ def setup_args():
|
|||
parser.add_argument("--output-volume", type=float, default=default_arguments['output-volume'], help="Adjusts volume of output")
|
||||
parser.add_argument("--results-folder", type=str, default=default_arguments['results-folder'], help="Sets output directory")
|
||||
|
||||
parser.add_argument("--hf-token", type=str, default=default_arguments['hf-token'], help="HuggingFace Token")
|
||||
parser.add_argument("--tts-backend", default=default_arguments['tts-backend'], help="Specifies which TTS backend to use.")
|
||||
|
||||
parser.add_argument("--autoregressive-model", default=default_arguments['autoregressive-model'], help="Specifies which autoregressive model to use for sampling.")
|
||||
|
@ -2072,6 +2115,7 @@ def setup_args():
|
|||
|
||||
parser.add_argument("--whisper-backend", default=default_arguments['whisper-backend'], action='store_true', help="Picks which whisper backend to use (openai/whisper, lightmare/whispercpp)")
|
||||
parser.add_argument("--whisper-model", default=default_arguments['whisper-model'], help="Specifies which whisper model to use for transcription.")
|
||||
parser.add_argument("--whisper-batchsize", type=int, default=default_arguments['whisper-batchsize'], help="Specifies batch size for WhisperX")
|
||||
|
||||
parser.add_argument("--training-default-halfp", action='store_true', default=default_arguments['training-default-halfp'], help="Training default: halfp")
|
||||
parser.add_argument("--training-default-bnb", action='store_true', default=default_arguments['training-default-bnb'], help="Training default: bnb")
|
||||
|
@ -2130,6 +2174,7 @@ def get_default_settings( hypenated=True ):
|
|||
'output-volume': args.output_volume,
|
||||
'results-folder': args.results_folder,
|
||||
|
||||
'hf-token': args.hf_token,
|
||||
'tts-backend': args.tts_backend,
|
||||
|
||||
'autoregressive-model': args.autoregressive_model,
|
||||
|
@ -2141,6 +2186,7 @@ def get_default_settings( hypenated=True ):
|
|||
|
||||
'whisper-backend': args.whisper_backend,
|
||||
'whisper-model': args.whisper_model,
|
||||
'whisper-batchsize': args.whisper_batchsize,
|
||||
|
||||
'training-default-halfp': args.training_default_halfp,
|
||||
'training-default-bnb': args.training_default_bnb,
|
||||
|
@ -2178,6 +2224,7 @@ def update_args( **kwargs ):
|
|||
args.output_volume = settings['output_volume']
|
||||
args.results_folder = settings['results_folder']
|
||||
|
||||
args.hf_token = settings['hf_token']
|
||||
args.tts_backend = settings['tts_backend']
|
||||
|
||||
args.autoregressive_model = settings['autoregressive_model']
|
||||
|
@ -2189,6 +2236,7 @@ def update_args( **kwargs ):
|
|||
|
||||
args.whisper_backend = settings['whisper_backend']
|
||||
args.whisper_model = settings['whisper_model']
|
||||
args.whisper_batchsize = settings['whisper_batchsize']
|
||||
|
||||
args.training_default_halfp = settings['training_default_halfp']
|
||||
args.training_default_bnb = settings['training_default_bnb']
|
||||
|
@ -2529,10 +2577,8 @@ def unload_voicefixer():
|
|||
|
||||
def load_whisper_model(language=None, model_name=None, progress=None):
|
||||
global whisper_model
|
||||
|
||||
if model_name == "m-bain/whisperx":
|
||||
print("WhisperX has been removed. Reverting to openai/whisper. Apologies for the inconvenience.")
|
||||
model_name = "openai/whisper"
|
||||
global whisper_vad
|
||||
global whisper_diarize
|
||||
|
||||
if args.whisper_backend not in WHISPER_BACKENDS:
|
||||
raise Exception(f"unavailable backend: {args.whisper_backend}")
|
||||
|
@ -2564,6 +2610,25 @@ def load_whisper_model(language=None, model_name=None, progress=None):
|
|||
|
||||
b_lang = language.encode('ascii')
|
||||
whisper_model = Whisper(model_name, models_dir='./models/', language=b_lang)
|
||||
elif args.whisper_backend == "m-bain/whisperx":
|
||||
import whisperx
|
||||
device = "cuda" if get_device_name() == "cuda" else "cpu"
|
||||
whisper_model = whisperx.load_model(model_name, device)
|
||||
|
||||
if not args.hf_token:
|
||||
print("No huggingface token used, needs to be saved in environment variable, otherwise will throw error loading VAD model.")
|
||||
|
||||
try:
|
||||
from pyannote.audio import Inference, Pipeline
|
||||
whisper_vad = Inference(
|
||||
"pyannote/segmentation",
|
||||
pre_aggregation_hook=lambda segmentation: segmentation,
|
||||
use_auth_token=args.hf_token,
|
||||
device=torch.device(device),
|
||||
)
|
||||
whisper_diarize = Pipeline.from_pretrained("pyannote/speaker-diarization@2.1",use_auth_token=args.hf_token)
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
print("Loaded Whisper model")
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user