remove redundant phonemize for vall-e (oops), quantize all files and then phonemize all files for cope optimization, load alignment model once instead of for every transcription (speedup with whisperx)

master
mrq 2023-03-23 00:22:25 +07:00
parent 19c0854e6a
commit d2a9ab9e41
2 changed files with 66 additions and 21 deletions

@ -1,7 +1,3 @@
data_root: ./training/${voice}/
ckpt_root: ./training/${voice}/ckpt/
log_root: ./training/${voice}/logs/
data_dirs: [./training/${voice}/valle/]
spkr_name_getter: "lambda p: p.parts[-3]" # "lambda p: p.parts[-1].split('-')[0]"

@ -65,8 +65,7 @@ MAX_TRAINING_DURATION = 11.6097505669
VALLE_ENABLED = False
try:
from vall_e.emb.qnt import encode as quantize
# from vall_e.emb.g2p import encode as phonemize
from vall_e.emb.qnt import encode as valle_quantize
VALLE_ENABLED = True
except Exception as e:
@ -80,9 +79,12 @@ tts = None
tts_loading = False
webui = None
voicefixer = None
whisper_model = None
whisper_vad = None
whisper_diarize = None
whisper_align_model = None
training_state = None
current_voice = None
@ -1165,6 +1167,8 @@ def whisper_transcribe( file, language=None ):
global whisper_model
global whisper_vad
global whisper_diarize
global whisper_align_model
if not whisper_model:
load_whisper_model(language=language)
@ -1208,7 +1212,7 @@ def whisper_transcribe( file, language=None ):
else:
result = whisper_model.transcribe(file)
align_model, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
align_model, metadata = whisper_align_model
result_aligned = whisperx.align(result["segments"], align_model, metadata, file, device)
if whisper_diarize:
@ -1462,9 +1466,6 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
lines = { 'training': [], 'validation': [] }
segments = {}
if args.tts_backend == "vall-e":
phonemize = True
for filename in results:
use_segment = use_segments
@ -1541,6 +1542,11 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
'phonemes': segment['phonemes'] if 'phonemes' in segment else None
}
jobs = {
'quantize': [[], []],
'phonemize': [[], []],
}
for file in enumerate_progress(segments, desc="Parsing segments", progress=progress):
result = segments[file]
path = f'{indir}/audio/{file}'
@ -1559,9 +1565,7 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
phonemes = result['phonemes']
if phonemize and phonemes is None:
phonemes = phonemizer( text, language=lang )
if phonemize:
text = phonemes
normalized = normalizer(text) if normalize else text
if len(text) > 200:
@ -1587,7 +1591,7 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
if not culled and audio_length > 0:
culled = duration < audio_length
line = f'audio/{file}|{text}'
line = f'audio/{file}|{phonemes if phonemize and phonemes else text}'
lines['training' if not culled else 'validation'].append(line)
@ -1596,16 +1600,42 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
os.makedirs(f'{indir}/valle/', exist_ok=True)
if not os.path.exists(f'{indir}/valle/{file.replace(".wav",".qnt.pt")}'):
from vall_e.emb.qnt import encode as quantize
quantized = quantize( waveform, sample_rate ).cpu()
qnt_file = f'{indir}/valle/{file.replace(".wav",".qnt.pt")}'
if not os.path.exists(qnt_file):
jobs['quantize'][0].append(qnt_file)
jobs['quantize'][1].append((waveform, sample_rate))
"""
quantized = valle_quantize( waveform, sample_rate ).cpu()
torch.save(quantized, f'{indir}/valle/{file.replace(".wav",".qnt.pt")}')
print("Quantized:", file)
if not os.path.exists(f'{indir}/valle/{file.replace(".wav",".phn.txt")}'):
from vall_e.emb.g2p import encode as phonemize
phonemized = phonemize( normalized )
"""
phn_file = f'{indir}/valle/{file.replace(".wav",".phn.txt")}'
if not os.path.exists(phn_file):
jobs['phonemize'][0].append(phn_file)
jobs['phonemize'][1].append(normalized)
"""
phonemized = valle_phonemize( normalized )
open(f'{indir}/valle/{file.replace(".wav",".phn.txt")}', 'w', encoding='utf-8').write(" ".join(phonemized))
print("Phonemized:", file, normalized, text)
"""
for i in enumerate_progress(range(len(jobs['quantize'][0])), desc="Quantizing", progress=progress):
qnt_file = jobs['quantize'][0][i]
waveform, sample_rate = jobs['quantize'][1][i]
quantized = valle_quantize( waveform, sample_rate ).cpu()
torch.save(quantized, qnt_file)
print("Quantized:", file)
for i in enumerate_progress(range(len(jobs['phonemize'][0])), desc="Phonemizing", progress=progress):
phn_file = jobs['phonemize'][0][i]
normalized = jobs['phonemize'][1][i]
phonemized = valle_phonemize( normalized )
open(phn_file, 'w', encoding='utf-8').write(" ".join(phonemized))
print("Phonemized:", file)
training_joined = "\n".join(lines['training'])
validation_joined = "\n".join(lines['validation'])
@ -2635,6 +2665,7 @@ def load_whisper_model(language=None, model_name=None, progress=None):
global whisper_model
global whisper_vad
global whisper_diarize
global whisper_align_model
if args.whisper_backend not in WHISPER_BACKENDS:
raise Exception(f"unavailable backend: {args.whisper_backend}")
@ -2683,13 +2714,31 @@ def load_whisper_model(language=None, model_name=None, progress=None):
device=torch.device(device),
)
whisper_diarize = Pipeline.from_pretrained("pyannote/speaker-diarization@2.1",use_auth_token=args.hf_token)
except Exception as e:
pass
whisper_align_model = whisperx.load_align_model(model_name="WAV2VEC2_ASR_LARGE_LV60K_960H" if language=="en" else None, language_code=language, device=device)
print("Loaded Whisper model")
def unload_whisper():
global whisper_model
global whisper_vad
global whisper_diarize
global whisper_align_model
if whisper_vad:
del whisper_vad
whisper_vad = None
if whisper_diarize:
del whisper_diarize
whisper_diarize = None
if whisper_align_model:
del whisper_align_model
whisper_align_model = None
if whisper_model:
del whisper_model