uses gitmylo/bark-voice-cloning-HuBERT-quantizer for creating custom voices (it slightly works better over the base method, but still not very good desu)

This commit is contained in:
mrq 2023-07-03 02:46:10 +00:00
parent 547e1d1277
commit 6c3f48efba
2 changed files with 47 additions and 27 deletions

View File

@ -42,9 +42,6 @@ from tortoise.utils.audio import load_audio, load_voice, load_voices, get_voice_
from tortoise.utils.text import split_and_recombine_text
from tortoise.utils.device import get_device_name, set_device_name, get_device_count, get_device_vram, get_device_batch_size, do_gc
from whisper.normalizers.english import EnglishTextNormalizer
from whisper.normalizers.basic import BasicTextNormalizer
from whisper.tokenizer import LANGUAGES
MODELS['dvae.pth'] = "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/3704aea61678e7e468a06d8eea121dba368a798e/.models/dvae.pth"
@ -68,6 +65,19 @@ MAX_TRAINING_DURATION = 11.6097505669
VALLE_ENABLED = False
BARK_ENABLED = False
VERBOSE_DEBUG = True
try:
from whisper.normalizers.english import EnglishTextNormalizer
from whisper.normalizers.basic import BasicTextNormalizer
from whisper.tokenizer import LANGUAGES
print("Whisper detected")
except Exception as e:
if VERBOSE_DEBUG:
print("Error:", e)
pass
try:
from vall_e.emb.qnt import encode as valle_quantize
from vall_e.emb.g2p import encode as valle_phonemize
@ -76,10 +86,11 @@ try:
import soundfile
print("VALL-E detected")
VALLE_ENABLED = True
except Exception as e:
if False: # args.tts_backend == "vall-e":
raise e
if VERBOSE_DEBUG:
print("Error:", e)
pass
if VALLE_ENABLED:
@ -93,27 +104,39 @@ try:
from scipy.io.wavfile import write as write_wav
print("Bark detected")
BARK_ENABLED = True
except Exception as e:
if False: # args.tts_backend == "bark":
raise e
if VERBOSE_DEBUG:
print("Error:", e)
pass
if BARK_ENABLED:
try:
from vocos import Vocos
VOCOS_ENABLED = True
print("Vocos detected")
except Exception as e:
if VERBOSE_DEBUG:
print("Error:", e)
VOCOS_ENABLED = False
try:
from hubert.hubert_manager import HuBERTManager
from hubert.pre_kmeans_hubert import CustomHubert
from hubert.customtokenizer import CustomTokenizer
hubert_manager = HuBERTManager()
hubert_manager.make_sure_hubert_installed()
hubert_manager.make_sure_tokenizer_installed()
HUBERT_ENABLED = True
print("HuBERT detected")
except Exception as e:
if VERBOSE_DEBUG:
print("Error:", e)
HUBERT_ENABLED = False
if BARK_ENABLED:
TTSES.append('bark')
def semantic_to_audio_tokens(
@ -190,9 +213,9 @@ if BARK_ENABLED:
wav = convert_audio(wav, sr, model.sample_rate, model.channels)
# generate semantic tokens
if HUBERT_ENABLED:
wav = wav.to(device)
wav = wav.to(self.device)
# Extract discrete codes from EnCodec
with torch.no_grad():
@ -202,23 +225,20 @@ if BARK_ENABLED:
# get seconds of audio
seconds = wav.shape[-1] / model.sample_rate
hubert_manager = HuBERTManager()
hubert_manager.make_sure_hubert_installed()
hubert_manager.make_sure_tokenizer_installed()
from hubert.pre_kmeans_hubert import CustomHubert
from hubert.customtokenizer import CustomTokenizer
# Load the HuBERT model
hubert_model = CustomHubert(checkpoint_path='./models/hubert/hubert.pt').to(device)
hubert_model = CustomHubert(checkpoint_path='./data/models/hubert/hubert.pt').to(self.device)
# Load the CustomTokenizer model
tokenizer = CustomTokenizer.load_from_checkpoint('./models/hubert/tokenizer.pth').to(device)
tokenizer = CustomTokenizer.load_from_checkpoint('./data/models/hubert/tokenizer.pth').to(self.device)
semantic_vectors = hubert_model.forward(wav, input_sample_hz=model.sample_rate)
semantic_tokens = tokenizer.get_token(semantic_vectors)
# move codes to cpu
codes = codes.cpu().numpy()
# move semantic tokens to cpu
semantic_tokens = semantic_tokens.cpu().numpy()
else:
# Load and pre-process the audio waveform
model = load_codec_model(use_gpu=True)
wav, sr = torchaudio.load(audio_filepath)
wav = convert_audio(wav, sr, model.sample_rate, model.channels)
wav = wav.unsqueeze(0).to(self.device)
# Extract discrete codes from EnCodec
@ -1358,6 +1378,10 @@ def compute_latents(voice=None, voice_samples=None, voice_latents_chunks=0, orig
if hasattr(tts, "loading") and tts.loading:
raise Exception("TTS is still initializing...")
if args.tts_backend == "bark":
tts.create_voice( voice )
return
if args.autoregressive_model == "auto":
tts.load_autoregressive_model(deduce_autoregressive_model(voice))

View File

@ -169,10 +169,6 @@ def reset_generate_settings_proxy():
return tuple(res)
def compute_latents_proxy(voice, voice_latents_chunks, original_ar, original_diffusion, progress=gr.Progress(track_tqdm=True)):
if args.tts_backend == "bark":
global tts
tts.create_voice( voice )
return voice
compute_latents( voice=voice, voice_latents_chunks=voice_latents_chunks, original_ar=original_ar, original_diffusion=original_diffusion )
return voice