From d2a9ab9e4120bb6eeb333d5f55c2fd4802b4d33a Mon Sep 17 00:00:00 2001 From: mrq Date: Thu, 23 Mar 2023 00:22:25 +0000 Subject: [PATCH] remove redundant phonemize for vall-e (oops), quantize all files and then phonemize all files for cope optimization, load alignment model once instead of for every transcription (speedup with whisperx) --- models/.template.valle.yaml | 4 -- src/utils.py | 81 +++++++++++++++++++++++++++++-------- 2 files changed, 65 insertions(+), 20 deletions(-) diff --git a/models/.template.valle.yaml b/models/.template.valle.yaml index 33d74df..91a5dd5 100755 --- a/models/.template.valle.yaml +++ b/models/.template.valle.yaml @@ -1,7 +1,3 @@ -data_root: ./training/${voice}/ -ckpt_root: ./training/${voice}/ckpt/ -log_root: ./training/${voice}/logs/ - data_dirs: [./training/${voice}/valle/] spkr_name_getter: "lambda p: p.parts[-3]" # "lambda p: p.parts[-1].split('-')[0]" diff --git a/src/utils.py b/src/utils.py index 1cda2c1..80822e8 100755 --- a/src/utils.py +++ b/src/utils.py @@ -65,8 +65,7 @@ MAX_TRAINING_DURATION = 11.6097505669 VALLE_ENABLED = False try: - from vall_e.emb.qnt import encode as quantize - # from vall_e.emb.g2p import encode as phonemize + from vall_e.emb.qnt import encode as valle_quantize VALLE_ENABLED = True except Exception as e: @@ -80,9 +79,12 @@ tts = None tts_loading = False webui = None voicefixer = None + whisper_model = None whisper_vad = None whisper_diarize = None +whisper_align_model = None + training_state = None current_voice = None @@ -1165,6 +1167,8 @@ def whisper_transcribe( file, language=None ): global whisper_model global whisper_vad global whisper_diarize + global whisper_align_model + if not whisper_model: load_whisper_model(language=language) @@ -1208,7 +1212,7 @@ def whisper_transcribe( file, language=None ): else: result = whisper_model.transcribe(file) - align_model, metadata = whisperx.load_align_model(language_code=result["language"], device=device) + align_model, metadata = whisper_align_model result_aligned = whisperx.align(result["segments"], align_model, metadata, file, device) if whisper_diarize: @@ -1462,9 +1466,6 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p lines = { 'training': [], 'validation': [] } segments = {} - if args.tts_backend == "vall-e": - phonemize = True - for filename in results: use_segment = use_segments @@ -1541,6 +1542,11 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p 'phonemes': segment['phonemes'] if 'phonemes' in segment else None } + jobs = { + 'quantize': [[], []], + 'phonemize': [[], []], + } + for file in enumerate_progress(segments, desc="Parsing segments", progress=progress): result = segments[file] path = f'{indir}/audio/{file}' @@ -1559,9 +1565,7 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p phonemes = result['phonemes'] if phonemize and phonemes is None: phonemes = phonemizer( text, language=lang ) - if phonemize: - text = phonemes - + normalized = normalizer(text) if normalize else text if len(text) > 200: @@ -1587,7 +1591,7 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p if not culled and audio_length > 0: culled = duration < audio_length - line = f'audio/{file}|{text}' + line = f'audio/{file}|{phonemes if phonemize and phonemes else text}' lines['training' if not culled else 'validation'].append(line) @@ -1596,16 +1600,42 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p os.makedirs(f'{indir}/valle/', exist_ok=True) - if not os.path.exists(f'{indir}/valle/{file.replace(".wav",".qnt.pt")}'): - from vall_e.emb.qnt import encode as quantize - quantized = quantize( waveform, sample_rate ).cpu() + qnt_file = f'{indir}/valle/{file.replace(".wav",".qnt.pt")}' + if not os.path.exists(qnt_file): + jobs['quantize'][0].append(qnt_file) + jobs['quantize'][1].append((waveform, sample_rate)) + """ + quantized = valle_quantize( waveform, sample_rate ).cpu() torch.save(quantized, f'{indir}/valle/{file.replace(".wav",".qnt.pt")}') print("Quantized:", file) + """ - if not os.path.exists(f'{indir}/valle/{file.replace(".wav",".phn.txt")}'): - from vall_e.emb.g2p import encode as phonemize - phonemized = phonemize( normalized ) + phn_file = f'{indir}/valle/{file.replace(".wav",".phn.txt")}' + if not os.path.exists(phn_file): + jobs['phonemize'][0].append(phn_file) + jobs['phonemize'][1].append(normalized) + """ + phonemized = valle_phonemize( normalized ) open(f'{indir}/valle/{file.replace(".wav",".phn.txt")}', 'w', encoding='utf-8').write(" ".join(phonemized)) + print("Phonemized:", file, normalized, text) + """ + + for i in enumerate_progress(range(len(jobs['quantize'][0])), desc="Quantizing", progress=progress): + qnt_file = jobs['quantize'][0][i] + waveform, sample_rate = jobs['quantize'][1][i] + + quantized = valle_quantize( waveform, sample_rate ).cpu() + torch.save(quantized, qnt_file) + print("Quantized:", file) + + for i in enumerate_progress(range(len(jobs['phonemize'][0])), desc="Phonemizing", progress=progress): + phn_file = jobs['phonemize'][0][i] + normalized = jobs['phonemize'][1][i] + + phonemized = valle_phonemize( normalized ) + open(phn_file, 'w', encoding='utf-8').write(" ".join(phonemized)) + print("Phonemized:", file) + training_joined = "\n".join(lines['training']) validation_joined = "\n".join(lines['validation']) @@ -2635,6 +2665,7 @@ def load_whisper_model(language=None, model_name=None, progress=None): global whisper_model global whisper_vad global whisper_diarize + global whisper_align_model if args.whisper_backend not in WHISPER_BACKENDS: raise Exception(f"unavailable backend: {args.whisper_backend}") @@ -2683,13 +2714,31 @@ def load_whisper_model(language=None, model_name=None, progress=None): device=torch.device(device), ) whisper_diarize = Pipeline.from_pretrained("pyannote/speaker-diarization@2.1",use_auth_token=args.hf_token) + except Exception as e: pass + + whisper_align_model = whisperx.load_align_model(model_name="WAV2VEC2_ASR_LARGE_LV60K_960H" if language=="en" else None, language_code=language, device=device) print("Loaded Whisper model") def unload_whisper(): global whisper_model + global whisper_vad + global whisper_diarize + global whisper_align_model + + if whisper_vad: + del whisper_vad + whisper_vad = None + + if whisper_diarize: + del whisper_diarize + whisper_diarize = None + + if whisper_align_model: + del whisper_align_model + whisper_align_model = None if whisper_model: del whisper_model