@@ -65,8 +65,7 @@ MAX_TRAINING_DURATION = 11.6097505669
 VALLE_ENABLED = False
 try:
-    from vall_e.emb.qnt import encode as quantize
-    # from vall_e.emb.g2p import encode as phonemize
+    from vall_e.emb.qnt import encode as valle_quantize
     VALLE_ENABLED = True
 except Exception as e:
@@ -80,9 +79,12 @@ tts = None
 tts_loading = False
 webui = None
 voicefixer = None
 whisper_model = None
+whisper_vad = None
+whisper_diarize = None
+whisper_align_model = None
 training_state = None
 current_voice = None
@@ -1165,6 +1167,8 @@ def whisper_transcribe( file, language=None ):
     global whisper_model
     global whisper_vad
+    global whisper_diarize
+    global whisper_align_model
     if not whisper_model:
         load_whisper_model(language=language)
@@ -1208,7 +1212,7 @@ def whisper_transcribe( file, language=None ):
     else:
         result = whisper_model.transcribe(file)
-        align_model, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
+        align_model, metadata = whisper_align_model
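+        # reuse the align model preloaded by load_whisper_model() instead of reloading it for every transcription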
         result_aligned = whisperx.align(result["segments"], align_model, metadata, file, device)
         if whisper_diarize:
@@ -1462,9 +1466,6 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
     lines = {'training': [], 'validation': []}
     segments = {}
-    if args.tts_backend == "vall-e":
-        phonemize = True
     for filename in results:
         use_segment = use_segments
@@ -1541,6 +1542,11 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
             'phonemes': segment['phonemes'] if 'phonemes' in segment else None
         }
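+    # queue VALL-E quantize/phonemize work here and run it as a batch once every segment has been parsed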
+    jobs = {
+        'quantize': [[], []],
+        'phonemize': [[], []],
+    }
     for file in enumerate_progress(segments, desc="Parsing segments", progress=progress):
         result = segments[file]
         path = f'{indir}/audio/{file}'
@@ -1559,9 +1565,7 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
         phonemes = result['phonemes']
         if phonemize and phonemes is None:
             phonemes = phonemizer(text, language=lang)
-        if phonemize:
-            text = phonemes
         normalized = normalizer(text) if normalize else text
         if len(text) > 200:
@@ -1587,7 +1591,7 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
         if not culled and audio_length > 0:
             culled = duration < audio_length
-        line = f'audio/{file}|{text}'
+        line = f'audio/{file}|{phonemes if phonemize and phonemes else text}'
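+        # prefer the phonemized transcript for the dataset line when phonemization is enabled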
         lines['training' if not culled else 'validation'].append(line)
@@ -1596,16 +1600,42 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
             os.makedirs(f'{indir}/valle/', exist_ok=True)
-            if not os.path.exists(f'{indir}/valle/{file.replace(".wav", ".qnt.pt")}'):
-                from vall_e.emb.qnt import encode as quantize
-                quantized = quantize(waveform, sample_rate).cpu()
+            qnt_file = f'{indir}/valle/{file.replace(".wav", ".qnt.pt")}'
+            if not os.path.exists(qnt_file):
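+                # defer quantization: remember the target path and the raw audio for the batch pass below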
+                jobs['quantize'][0].append(qnt_file)
+                jobs['quantize'][1].append((waveform, sample_rate))
+                """
+                quantized = valle_quantize(waveform, sample_rate).cpu()
                 torch.save(quantized, f'{indir}/valle/{file.replace(".wav", ".qnt.pt")}')
                 print("Quantized:", file)
-            if not os.path.exists(f'{indir}/valle/{file.replace(".wav", ".phn.txt")}'):
-                from vall_e.emb.g2p import encode as phonemize
-                phonemized = phonemize(normalized)
+                """
+            phn_file = f'{indir}/valle/{file.replace(".wav", ".phn.txt")}'
+            if not os.path.exists(phn_file):
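+                # defer phonemization of the normalized transcript to the batch pass below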
+                jobs['phonemize'][0].append(phn_file)
+                jobs['phonemize'][1].append(normalized)
+                """
+                phonemized = valle_phonemize(normalized)
                 open(f'{indir}/valle/{file.replace(".wav", ".phn.txt")}', 'w', encoding='utf-8').write(" ".join(phonemized))
                 print("Phonemized:", file, normalized, text)
+                """
+    # after parsing, run the queued quantization jobs in one batch so progress can be reported for the whole pass
+    for i in enumerate_progress(range(len(jobs['quantize'][0])), desc="Quantizing", progress=progress):
+        qnt_file = jobs['quantize'][0][i]
+        waveform, sample_rate = jobs['quantize'][1][i]
+        quantized = valle_quantize(waveform, sample_rate).cpu()
+        torch.save(quantized, qnt_file)
+        print("Quantized:", qnt_file)
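+    # then run the queued phonemization jobs the same way, writing one .phn.txt per segment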
+    for i in enumerate_progress(range(len(jobs['phonemize'][0])), desc="Phonemizing", progress=progress):
+        phn_file = jobs['phonemize'][0][i]
+        normalized = jobs['phonemize'][1][i]
+        phonemized = valle_phonemize(normalized)
+        open(phn_file, 'w', encoding='utf-8').write(" ".join(phonemized))
+        print("Phonemized:", phn_file)
     training_joined = "\n".join(lines['training'])
     validation_joined = "\n".join(lines['validation'])
@@ -2635,6 +2665,7 @@ def load_whisper_model(language=None, model_name=None, progress=None):
     global whisper_model
     global whisper_vad
     global whisper_diarize
+    global whisper_align_model
     if args.whisper_backend not in WHISPER_BACKENDS:
         raise Exception(f"unavailable backend: {args.whisper_backend}")
@@ -2683,13 +2714,31 @@ def load_whisper_model(language=None, model_name=None, progress=None):
                 device=torch.device(device),
             )
             whisper_diarize = Pipeline.from_pretrained("pyannote/speaker-diarization@2.1", use_auth_token=args.hf_token)
         except Exception as e:
             pass
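+        # load the alignment model once here so whisper_transcribe() can reuse it instead of reloading it per file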
+        whisper_align_model = whisperx.load_align_model(model_name="WAV2VEC2_ASR_LARGE_LV60K_960H" if language == "en" else None, language_code=language, device=device)
     print("Loaded Whisper model")
 def unload_whisper():
     global whisper_model
+    global whisper_vad
+    global whisper_diarize
+    global whisper_align_model
+    if whisper_vad:
+        del whisper_vad
+        whisper_vad = None
+    if whisper_diarize:
+        del whisper_diarize
+        whisper_diarize = None
+    if whisper_align_model:
+        del whisper_align_model
+        whisper_align_model = None
     if whisper_model:
         del whisper_model