forked from mrq/ai-voice-cloning
remove redundant phonemize for vall-e (oops), quantize all files and then phonemize all files for cope optimization, load alignment model once instead of for every transcription (speedup with whisperx)
This commit is contained in:
parent
19c0854e6a
commit
d2a9ab9e41
|
@ -1,7 +1,3 @@
|
||||||
data_root: ./training/${voice}/
|
|
||||||
ckpt_root: ./training/${voice}/ckpt/
|
|
||||||
log_root: ./training/${voice}/logs/
|
|
||||||
|
|
||||||
data_dirs: [./training/${voice}/valle/]
|
data_dirs: [./training/${voice}/valle/]
|
||||||
spkr_name_getter: "lambda p: p.parts[-3]" # "lambda p: p.parts[-1].split('-')[0]"
|
spkr_name_getter: "lambda p: p.parts[-3]" # "lambda p: p.parts[-1].split('-')[0]"
|
||||||
|
|
||||||
|
|
81
src/utils.py
81
src/utils.py
|
@ -65,8 +65,7 @@ MAX_TRAINING_DURATION = 11.6097505669
|
||||||
VALLE_ENABLED = False
|
VALLE_ENABLED = False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from vall_e.emb.qnt import encode as quantize
|
from vall_e.emb.qnt import encode as valle_quantize
|
||||||
# from vall_e.emb.g2p import encode as phonemize
|
|
||||||
|
|
||||||
VALLE_ENABLED = True
|
VALLE_ENABLED = True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -80,9 +79,12 @@ tts = None
|
||||||
tts_loading = False
|
tts_loading = False
|
||||||
webui = None
|
webui = None
|
||||||
voicefixer = None
|
voicefixer = None
|
||||||
|
|
||||||
whisper_model = None
|
whisper_model = None
|
||||||
whisper_vad = None
|
whisper_vad = None
|
||||||
whisper_diarize = None
|
whisper_diarize = None
|
||||||
|
whisper_align_model = None
|
||||||
|
|
||||||
training_state = None
|
training_state = None
|
||||||
|
|
||||||
current_voice = None
|
current_voice = None
|
||||||
|
@ -1165,6 +1167,8 @@ def whisper_transcribe( file, language=None ):
|
||||||
global whisper_model
|
global whisper_model
|
||||||
global whisper_vad
|
global whisper_vad
|
||||||
global whisper_diarize
|
global whisper_diarize
|
||||||
|
global whisper_align_model
|
||||||
|
|
||||||
if not whisper_model:
|
if not whisper_model:
|
||||||
load_whisper_model(language=language)
|
load_whisper_model(language=language)
|
||||||
|
|
||||||
|
@ -1208,7 +1212,7 @@ def whisper_transcribe( file, language=None ):
|
||||||
else:
|
else:
|
||||||
result = whisper_model.transcribe(file)
|
result = whisper_model.transcribe(file)
|
||||||
|
|
||||||
align_model, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
|
align_model, metadata = whisper_align_model
|
||||||
result_aligned = whisperx.align(result["segments"], align_model, metadata, file, device)
|
result_aligned = whisperx.align(result["segments"], align_model, metadata, file, device)
|
||||||
|
|
||||||
if whisper_diarize:
|
if whisper_diarize:
|
||||||
|
@ -1462,9 +1466,6 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
|
||||||
lines = { 'training': [], 'validation': [] }
|
lines = { 'training': [], 'validation': [] }
|
||||||
segments = {}
|
segments = {}
|
||||||
|
|
||||||
if args.tts_backend == "vall-e":
|
|
||||||
phonemize = True
|
|
||||||
|
|
||||||
for filename in results:
|
for filename in results:
|
||||||
use_segment = use_segments
|
use_segment = use_segments
|
||||||
|
|
||||||
|
@ -1541,6 +1542,11 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
|
||||||
'phonemes': segment['phonemes'] if 'phonemes' in segment else None
|
'phonemes': segment['phonemes'] if 'phonemes' in segment else None
|
||||||
}
|
}
|
||||||
|
|
||||||
|
jobs = {
|
||||||
|
'quantize': [[], []],
|
||||||
|
'phonemize': [[], []],
|
||||||
|
}
|
||||||
|
|
||||||
for file in enumerate_progress(segments, desc="Parsing segments", progress=progress):
|
for file in enumerate_progress(segments, desc="Parsing segments", progress=progress):
|
||||||
result = segments[file]
|
result = segments[file]
|
||||||
path = f'{indir}/audio/{file}'
|
path = f'{indir}/audio/{file}'
|
||||||
|
@ -1559,9 +1565,7 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
|
||||||
phonemes = result['phonemes']
|
phonemes = result['phonemes']
|
||||||
if phonemize and phonemes is None:
|
if phonemize and phonemes is None:
|
||||||
phonemes = phonemizer( text, language=lang )
|
phonemes = phonemizer( text, language=lang )
|
||||||
if phonemize:
|
|
||||||
text = phonemes
|
|
||||||
|
|
||||||
normalized = normalizer(text) if normalize else text
|
normalized = normalizer(text) if normalize else text
|
||||||
|
|
||||||
if len(text) > 200:
|
if len(text) > 200:
|
||||||
|
@ -1587,7 +1591,7 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
|
||||||
if not culled and audio_length > 0:
|
if not culled and audio_length > 0:
|
||||||
culled = duration < audio_length
|
culled = duration < audio_length
|
||||||
|
|
||||||
line = f'audio/{file}|{text}'
|
line = f'audio/{file}|{phonemes if phonemize and phonemes else text}'
|
||||||
|
|
||||||
lines['training' if not culled else 'validation'].append(line)
|
lines['training' if not culled else 'validation'].append(line)
|
||||||
|
|
||||||
|
@ -1596,16 +1600,42 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
|
||||||
|
|
||||||
os.makedirs(f'{indir}/valle/', exist_ok=True)
|
os.makedirs(f'{indir}/valle/', exist_ok=True)
|
||||||
|
|
||||||
if not os.path.exists(f'{indir}/valle/{file.replace(".wav",".qnt.pt")}'):
|
qnt_file = f'{indir}/valle/{file.replace(".wav",".qnt.pt")}'
|
||||||
from vall_e.emb.qnt import encode as quantize
|
if not os.path.exists(qnt_file):
|
||||||
quantized = quantize( waveform, sample_rate ).cpu()
|
jobs['quantize'][0].append(qnt_file)
|
||||||
|
jobs['quantize'][1].append((waveform, sample_rate))
|
||||||
|
"""
|
||||||
|
quantized = valle_quantize( waveform, sample_rate ).cpu()
|
||||||
torch.save(quantized, f'{indir}/valle/{file.replace(".wav",".qnt.pt")}')
|
torch.save(quantized, f'{indir}/valle/{file.replace(".wav",".qnt.pt")}')
|
||||||
print("Quantized:", file)
|
print("Quantized:", file)
|
||||||
|
"""
|
||||||
|
|
||||||
if not os.path.exists(f'{indir}/valle/{file.replace(".wav",".phn.txt")}'):
|
phn_file = f'{indir}/valle/{file.replace(".wav",".phn.txt")}'
|
||||||
from vall_e.emb.g2p import encode as phonemize
|
if not os.path.exists(phn_file):
|
||||||
phonemized = phonemize( normalized )
|
jobs['phonemize'][0].append(phn_file)
|
||||||
|
jobs['phonemize'][1].append(normalized)
|
||||||
|
"""
|
||||||
|
phonemized = valle_phonemize( normalized )
|
||||||
open(f'{indir}/valle/{file.replace(".wav",".phn.txt")}', 'w', encoding='utf-8').write(" ".join(phonemized))
|
open(f'{indir}/valle/{file.replace(".wav",".phn.txt")}', 'w', encoding='utf-8').write(" ".join(phonemized))
|
||||||
|
print("Phonemized:", file, normalized, text)
|
||||||
|
"""
|
||||||
|
|
||||||
|
for i in enumerate_progress(range(len(jobs['quantize'][0])), desc="Quantizing", progress=progress):
|
||||||
|
qnt_file = jobs['quantize'][0][i]
|
||||||
|
waveform, sample_rate = jobs['quantize'][1][i]
|
||||||
|
|
||||||
|
quantized = valle_quantize( waveform, sample_rate ).cpu()
|
||||||
|
torch.save(quantized, qnt_file)
|
||||||
|
print("Quantized:", file)
|
||||||
|
|
||||||
|
for i in enumerate_progress(range(len(jobs['phonemize'][0])), desc="Phonemizing", progress=progress):
|
||||||
|
phn_file = jobs['phonemize'][0][i]
|
||||||
|
normalized = jobs['phonemize'][1][i]
|
||||||
|
|
||||||
|
phonemized = valle_phonemize( normalized )
|
||||||
|
open(phn_file, 'w', encoding='utf-8').write(" ".join(phonemized))
|
||||||
|
print("Phonemized:", file)
|
||||||
|
|
||||||
|
|
||||||
training_joined = "\n".join(lines['training'])
|
training_joined = "\n".join(lines['training'])
|
||||||
validation_joined = "\n".join(lines['validation'])
|
validation_joined = "\n".join(lines['validation'])
|
||||||
|
@ -2635,6 +2665,7 @@ def load_whisper_model(language=None, model_name=None, progress=None):
|
||||||
global whisper_model
|
global whisper_model
|
||||||
global whisper_vad
|
global whisper_vad
|
||||||
global whisper_diarize
|
global whisper_diarize
|
||||||
|
global whisper_align_model
|
||||||
|
|
||||||
if args.whisper_backend not in WHISPER_BACKENDS:
|
if args.whisper_backend not in WHISPER_BACKENDS:
|
||||||
raise Exception(f"unavailable backend: {args.whisper_backend}")
|
raise Exception(f"unavailable backend: {args.whisper_backend}")
|
||||||
|
@ -2683,13 +2714,31 @@ def load_whisper_model(language=None, model_name=None, progress=None):
|
||||||
device=torch.device(device),
|
device=torch.device(device),
|
||||||
)
|
)
|
||||||
whisper_diarize = Pipeline.from_pretrained("pyannote/speaker-diarization@2.1",use_auth_token=args.hf_token)
|
whisper_diarize = Pipeline.from_pretrained("pyannote/speaker-diarization@2.1",use_auth_token=args.hf_token)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
whisper_align_model = whisperx.load_align_model(model_name="WAV2VEC2_ASR_LARGE_LV60K_960H" if language=="en" else None, language_code=language, device=device)
|
||||||
|
|
||||||
print("Loaded Whisper model")
|
print("Loaded Whisper model")
|
||||||
|
|
||||||
def unload_whisper():
|
def unload_whisper():
|
||||||
global whisper_model
|
global whisper_model
|
||||||
|
global whisper_vad
|
||||||
|
global whisper_diarize
|
||||||
|
global whisper_align_model
|
||||||
|
|
||||||
|
if whisper_vad:
|
||||||
|
del whisper_vad
|
||||||
|
whisper_vad = None
|
||||||
|
|
||||||
|
if whisper_diarize:
|
||||||
|
del whisper_diarize
|
||||||
|
whisper_diarize = None
|
||||||
|
|
||||||
|
if whisper_align_model:
|
||||||
|
del whisper_align_model
|
||||||
|
whisper_align_model = None
|
||||||
|
|
||||||
if whisper_model:
|
if whisper_model:
|
||||||
del whisper_model
|
del whisper_model
|
||||||
|
|
Loading…
Reference in New Issue
Block a user