forked from mrq/ai-voice-cloning
remove redundant phonemize for vall-e (oops), quantize all files and then phonemize all files for cope optimization, load alignment model once instead of for every transcription (speedup with whisperx)
This commit is contained in:
parent
19c0854e6a
commit
d2a9ab9e41
|
@ -1,7 +1,3 @@
|
|||
data_root: ./training/${voice}/
|
||||
ckpt_root: ./training/${voice}/ckpt/
|
||||
log_root: ./training/${voice}/logs/
|
||||
|
||||
data_dirs: [./training/${voice}/valle/]
|
||||
spkr_name_getter: "lambda p: p.parts[-3]" # "lambda p: p.parts[-1].split('-')[0]"
|
||||
|
||||
|
|
79
src/utils.py
79
src/utils.py
|
@ -65,8 +65,7 @@ MAX_TRAINING_DURATION = 11.6097505669
|
|||
VALLE_ENABLED = False
|
||||
|
||||
try:
|
||||
from vall_e.emb.qnt import encode as quantize
|
||||
# from vall_e.emb.g2p import encode as phonemize
|
||||
from vall_e.emb.qnt import encode as valle_quantize
|
||||
|
||||
VALLE_ENABLED = True
|
||||
except Exception as e:
|
||||
|
@ -80,9 +79,12 @@ tts = None
|
|||
tts_loading = False
|
||||
webui = None
|
||||
voicefixer = None
|
||||
|
||||
whisper_model = None
|
||||
whisper_vad = None
|
||||
whisper_diarize = None
|
||||
whisper_align_model = None
|
||||
|
||||
training_state = None
|
||||
|
||||
current_voice = None
|
||||
|
@ -1165,6 +1167,8 @@ def whisper_transcribe( file, language=None ):
|
|||
global whisper_model
|
||||
global whisper_vad
|
||||
global whisper_diarize
|
||||
global whisper_align_model
|
||||
|
||||
if not whisper_model:
|
||||
load_whisper_model(language=language)
|
||||
|
||||
|
@ -1208,7 +1212,7 @@ def whisper_transcribe( file, language=None ):
|
|||
else:
|
||||
result = whisper_model.transcribe(file)
|
||||
|
||||
align_model, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
|
||||
align_model, metadata = whisper_align_model
|
||||
result_aligned = whisperx.align(result["segments"], align_model, metadata, file, device)
|
||||
|
||||
if whisper_diarize:
|
||||
|
@ -1462,9 +1466,6 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
|
|||
lines = { 'training': [], 'validation': [] }
|
||||
segments = {}
|
||||
|
||||
if args.tts_backend == "vall-e":
|
||||
phonemize = True
|
||||
|
||||
for filename in results:
|
||||
use_segment = use_segments
|
||||
|
||||
|
@ -1541,6 +1542,11 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
|
|||
'phonemes': segment['phonemes'] if 'phonemes' in segment else None
|
||||
}
|
||||
|
||||
jobs = {
|
||||
'quantize': [[], []],
|
||||
'phonemize': [[], []],
|
||||
}
|
||||
|
||||
for file in enumerate_progress(segments, desc="Parsing segments", progress=progress):
|
||||
result = segments[file]
|
||||
path = f'{indir}/audio/{file}'
|
||||
|
@ -1559,8 +1565,6 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
|
|||
phonemes = result['phonemes']
|
||||
if phonemize and phonemes is None:
|
||||
phonemes = phonemizer( text, language=lang )
|
||||
if phonemize:
|
||||
text = phonemes
|
||||
|
||||
normalized = normalizer(text) if normalize else text
|
||||
|
||||
|
@ -1587,7 +1591,7 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
|
|||
if not culled and audio_length > 0:
|
||||
culled = duration < audio_length
|
||||
|
||||
line = f'audio/{file}|{text}'
|
||||
line = f'audio/{file}|{phonemes if phonemize and phonemes else text}'
|
||||
|
||||
lines['training' if not culled else 'validation'].append(line)
|
||||
|
||||
|
@ -1596,16 +1600,42 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
|
|||
|
||||
os.makedirs(f'{indir}/valle/', exist_ok=True)
|
||||
|
||||
if not os.path.exists(f'{indir}/valle/{file.replace(".wav",".qnt.pt")}'):
|
||||
from vall_e.emb.qnt import encode as quantize
|
||||
quantized = quantize( waveform, sample_rate ).cpu()
|
||||
qnt_file = f'{indir}/valle/{file.replace(".wav",".qnt.pt")}'
|
||||
if not os.path.exists(qnt_file):
|
||||
jobs['quantize'][0].append(qnt_file)
|
||||
jobs['quantize'][1].append((waveform, sample_rate))
|
||||
"""
|
||||
quantized = valle_quantize( waveform, sample_rate ).cpu()
|
||||
torch.save(quantized, f'{indir}/valle/{file.replace(".wav",".qnt.pt")}')
|
||||
print("Quantized:", file)
|
||||
"""
|
||||
|
||||
if not os.path.exists(f'{indir}/valle/{file.replace(".wav",".phn.txt")}'):
|
||||
from vall_e.emb.g2p import encode as phonemize
|
||||
phonemized = phonemize( normalized )
|
||||
phn_file = f'{indir}/valle/{file.replace(".wav",".phn.txt")}'
|
||||
if not os.path.exists(phn_file):
|
||||
jobs['phonemize'][0].append(phn_file)
|
||||
jobs['phonemize'][1].append(normalized)
|
||||
"""
|
||||
phonemized = valle_phonemize( normalized )
|
||||
open(f'{indir}/valle/{file.replace(".wav",".phn.txt")}', 'w', encoding='utf-8').write(" ".join(phonemized))
|
||||
print("Phonemized:", file, normalized, text)
|
||||
"""
|
||||
|
||||
for i in enumerate_progress(range(len(jobs['quantize'][0])), desc="Quantizing", progress=progress):
|
||||
qnt_file = jobs['quantize'][0][i]
|
||||
waveform, sample_rate = jobs['quantize'][1][i]
|
||||
|
||||
quantized = valle_quantize( waveform, sample_rate ).cpu()
|
||||
torch.save(quantized, qnt_file)
|
||||
print("Quantized:", file)
|
||||
|
||||
for i in enumerate_progress(range(len(jobs['phonemize'][0])), desc="Phonemizing", progress=progress):
|
||||
phn_file = jobs['phonemize'][0][i]
|
||||
normalized = jobs['phonemize'][1][i]
|
||||
|
||||
phonemized = valle_phonemize( normalized )
|
||||
open(phn_file, 'w', encoding='utf-8').write(" ".join(phonemized))
|
||||
print("Phonemized:", file)
|
||||
|
||||
|
||||
training_joined = "\n".join(lines['training'])
|
||||
validation_joined = "\n".join(lines['validation'])
|
||||
|
@ -2635,6 +2665,7 @@ def load_whisper_model(language=None, model_name=None, progress=None):
|
|||
global whisper_model
|
||||
global whisper_vad
|
||||
global whisper_diarize
|
||||
global whisper_align_model
|
||||
|
||||
if args.whisper_backend not in WHISPER_BACKENDS:
|
||||
raise Exception(f"unavailable backend: {args.whisper_backend}")
|
||||
|
@ -2683,13 +2714,31 @@ def load_whisper_model(language=None, model_name=None, progress=None):
|
|||
device=torch.device(device),
|
||||
)
|
||||
whisper_diarize = Pipeline.from_pretrained("pyannote/speaker-diarization@2.1",use_auth_token=args.hf_token)
|
||||
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
whisper_align_model = whisperx.load_align_model(model_name="WAV2VEC2_ASR_LARGE_LV60K_960H" if language=="en" else None, language_code=language, device=device)
|
||||
|
||||
print("Loaded Whisper model")
|
||||
|
||||
def unload_whisper():
|
||||
global whisper_model
|
||||
global whisper_vad
|
||||
global whisper_diarize
|
||||
global whisper_align_model
|
||||
|
||||
if whisper_vad:
|
||||
del whisper_vad
|
||||
whisper_vad = None
|
||||
|
||||
if whisper_diarize:
|
||||
del whisper_diarize
|
||||
whisper_diarize = None
|
||||
|
||||
if whisper_align_model:
|
||||
del whisper_align_model
|
||||
whisper_align_model = None
|
||||
|
||||
if whisper_model:
|
||||
del whisper_model
|
||||
|
|
Loading…
Reference in New Issue
Block a user