diff --git a/src/utils.py b/src/utils.py index 71d0592..5eb536b 100755 --- a/src/utils.py +++ b/src/utils.py @@ -42,9 +42,6 @@ from tortoise.utils.audio import load_audio, load_voice, load_voices, get_voice_ from tortoise.utils.text import split_and_recombine_text from tortoise.utils.device import get_device_name, set_device_name, get_device_count, get_device_vram, get_device_batch_size, do_gc -from whisper.normalizers.english import EnglishTextNormalizer -from whisper.normalizers.basic import BasicTextNormalizer -from whisper.tokenizer import LANGUAGES MODELS['dvae.pth'] = "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/3704aea61678e7e468a06d8eea121dba368a798e/.models/dvae.pth" @@ -68,6 +65,19 @@ MAX_TRAINING_DURATION = 11.6097505669 VALLE_ENABLED = False BARK_ENABLED = False +VERBOSE_DEBUG = True + +try: + from whisper.normalizers.english import EnglishTextNormalizer + from whisper.normalizers.basic import BasicTextNormalizer + from whisper.tokenizer import LANGUAGES + + print("Whisper detected") +except Exception as e: + if VERBOSE_DEBUG: + print("Error:", e) + pass + try: from vall_e.emb.qnt import encode as valle_quantize from vall_e.emb.g2p import encode as valle_phonemize @@ -76,10 +86,11 @@ try: import soundfile + print("VALL-E detected") VALLE_ENABLED = True except Exception as e: - if False: # args.tts_backend == "vall-e": - raise e + if VERBOSE_DEBUG: + print("Error:", e) pass if VALLE_ENABLED: @@ -93,27 +104,39 @@ try: from scipy.io.wavfile import write as write_wav + print("Bark detected") BARK_ENABLED = True except Exception as e: - if False: # args.tts_backend == "bark": - raise e + if VERBOSE_DEBUG: + print("Error:", e) pass if BARK_ENABLED: try: from vocos import Vocos VOCOS_ENABLED = True + print("Vocos detected") except Exception as e: + if VERBOSE_DEBUG: + print("Error:", e) VOCOS_ENABLED = False try: from hubert.hubert_manager import HuBERTManager + from hubert.pre_kmeans_hubert import CustomHubert + from hubert.customtokenizer import CustomTokenizer + + hubert_manager = HuBERTManager() + hubert_manager.make_sure_hubert_installed() + hubert_manager.make_sure_tokenizer_installed() HUBERT_ENABLED = True + print("HuBERT detected") except Exception as e: + if VERBOSE_DEBUG: + print("Error:", e) HUBERT_ENABLED = False -if BARK_ENABLED: TTSES.append('bark') def semantic_to_audio_tokens( @@ -190,9 +213,9 @@ if BARK_ENABLED: wav = convert_audio(wav, sr, model.sample_rate, model.channels) # generate semantic tokens - + if HUBERT_ENABLED: - wav = wav.to(device) + wav = wav.to(self.device) # Extract discrete codes from EnCodec with torch.no_grad(): @@ -202,23 +225,20 @@ if BARK_ENABLED: # get seconds of audio seconds = wav.shape[-1] / model.sample_rate - hubert_manager = HuBERTManager() - hubert_manager.make_sure_hubert_installed() - hubert_manager.make_sure_tokenizer_installed() - - from hubert.pre_kmeans_hubert import CustomHubert - from hubert.customtokenizer import CustomTokenizer - # Load the HuBERT model - hubert_model = CustomHubert(checkpoint_path='./models/hubert/hubert.pt').to(device) + hubert_model = CustomHubert(checkpoint_path='./data/models/hubert/hubert.pt').to(self.device) # Load the CustomTokenizer model - tokenizer = CustomTokenizer.load_from_checkpoint('./models/hubert/tokenizer.pth').to(device) + tokenizer = CustomTokenizer.load_from_checkpoint('./data/models/hubert/tokenizer.pth').to(self.device) + + semantic_vectors = hubert_model.forward(wav, input_sample_hz=model.sample_rate) + semantic_tokens = tokenizer.get_token(semantic_vectors) + + # move codes to cpu + codes = codes.cpu().numpy() + # move semantic tokens to cpu + semantic_tokens = semantic_tokens.cpu().numpy() else: - # Load and pre-process the audio waveform - model = load_codec_model(use_gpu=True) - wav, sr = torchaudio.load(audio_filepath) - wav = convert_audio(wav, sr, model.sample_rate, model.channels) wav = wav.unsqueeze(0).to(self.device) # Extract discrete codes from EnCodec @@ -1358,6 +1378,10 @@ def compute_latents(voice=None, voice_samples=None, voice_latents_chunks=0, orig if hasattr(tts, "loading") and tts.loading: raise Exception("TTS is still initializing...") + if args.tts_backend == "bark": + tts.create_voice( voice ) + return + if args.autoregressive_model == "auto": tts.load_autoregressive_model(deduce_autoregressive_model(voice)) diff --git a/src/webui.py b/src/webui.py index 368ec64..329c237 100755 --- a/src/webui.py +++ b/src/webui.py @@ -169,10 +169,6 @@ def reset_generate_settings_proxy(): return tuple(res) def compute_latents_proxy(voice, voice_latents_chunks, original_ar, original_diffusion, progress=gr.Progress(track_tqdm=True)): - if args.tts_backend == "bark": - global tts - tts.create_voice( voice ) - return voice compute_latents( voice=voice, voice_latents_chunks=voice_latents_chunks, original_ar=original_ar, original_diffusion=original_diffusion ) return voice