From 547e1d1277a7022c82c2b305d14df378d2366971 Mon Sep 17 00:00:00 2001 From: mrq Date: Mon, 3 Jul 2023 01:22:02 +0000 Subject: [PATCH] updated bark support, it'll also query for vocos, it actually works (I don't know what specifically was the issue) --- src/utils.py | 128 +++++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 109 insertions(+), 19 deletions(-) diff --git a/src/utils.py b/src/utils.py index ac2209a..71d0592 100755 --- a/src/utils.py +++ b/src/utils.py @@ -86,6 +86,7 @@ if VALLE_ENABLED: TTSES.append('vall-e') try: + from bark import text_to_semantic from bark.generation import SAMPLE_RATE as BARK_SAMPLE_RATE, ALLOWED_PROMPTS, preload_models, codec_decode, generate_coarse, generate_fine, generate_text_semantic, load_codec_model from bark.api import generate_audio as bark_generate_audio from encodec.utils import convert_audio @@ -98,12 +99,48 @@ except Exception as e: raise e pass +if BARK_ENABLED: + try: + from vocos import Vocos + VOCOS_ENABLED = True + except Exception as e: + VOCOS_ENABLED = False + + try: + from hubert.hubert_manager import HuBERTManager + + HUBERT_ENABLED = True + except Exception as e: + HUBERT_ENABLED = False + if BARK_ENABLED: TTSES.append('bark') + + def semantic_to_audio_tokens( + semantic_tokens, + history_prompt = None, + temp = 0.7, + silent = False, + output_full = False, + ): + coarse_tokens = generate_coarse( + semantic_tokens, history_prompt=history_prompt, temp=temp, silent=silent, use_kv_caching=True + ) + fine_tokens = generate_fine(coarse_tokens, history_prompt=history_prompt, temp=0.5) + + if output_full: + full_generation = { + "semantic_prompt": semantic_tokens, + "coarse_prompt": coarse_tokens, + "fine_prompt": fine_tokens, + } + return full_generation + return fine_tokens + class Bark_TTS(): def __init__(self, small=False): self.input_sample_rate = BARK_SAMPLE_RATE - self.output_sample_rate = args.output_sample_rate + self.output_sample_rate = BARK_SAMPLE_RATE # args.output_sample_rate preload_models( text_use_gpu=True, @@ -118,7 +155,12 @@ if BARK_ENABLED: force_reload=False ) - def create_voice( self, voice, device='cuda' ): + self.device = get_device_name() + + if VOCOS_ENABLED: + self.vocos = Vocos.from_pretrained("charactr/vocos-encodec-24khz").to(self.device) + + def create_voice( self, voice ): transcription_json = f'./training/{voice}/whisper.json' if not os.path.exists(transcription_json): raise f"Transcription for voice not found: {voice}" @@ -146,29 +188,75 @@ if BARK_ENABLED: model = load_codec_model(use_gpu=True) wav, sr = torchaudio.load(audio_filepath) wav = convert_audio(wav, sr, model.sample_rate, model.channels) - wav = wav.unsqueeze(0).to(device) - - # Extract discrete codes from EnCodec - with torch.no_grad(): - encoded_frames = model.encode(wav) - codes = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1).squeeze().cpu().numpy() # [n_q, T] - # get seconds of audio - seconds = wav.shape[-1] / model.sample_rate # generate semantic tokens - semantic_tokens = generate_text_semantic(text, max_gen_duration_s=seconds, top_k=50, top_p=.95, temp=0.7) + + if HUBERT_ENABLED: + wav = wav.to(device) + + # Extract discrete codes from EnCodec + with torch.no_grad(): + encoded_frames = model.encode(wav.unsqueeze(0)) + codes = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1).squeeze() # [n_q, T] + + # get seconds of audio + seconds = wav.shape[-1] / model.sample_rate + + hubert_manager = HuBERTManager() + hubert_manager.make_sure_hubert_installed() + hubert_manager.make_sure_tokenizer_installed() + + from hubert.pre_kmeans_hubert import CustomHubert + from hubert.customtokenizer import CustomTokenizer + + # Load the HuBERT model + hubert_model = CustomHubert(checkpoint_path='./models/hubert/hubert.pt').to(device) + + # Load the CustomTokenizer model + tokenizer = CustomTokenizer.load_from_checkpoint('./models/hubert/tokenizer.pth').to(device) + else: + # Load and pre-process the audio waveform + model = load_codec_model(use_gpu=True) + wav, sr = torchaudio.load(audio_filepath) + wav = convert_audio(wav, sr, model.sample_rate, model.channels) + wav = wav.unsqueeze(0).to(self.device) + + # Extract discrete codes from EnCodec + with torch.no_grad(): + encoded_frames = model.encode(wav) + codes = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1).squeeze().cpu().numpy() # [n_q, T] + + # get seconds of audio + seconds = wav.shape[-1] / model.sample_rate + + # generate semantic tokens + semantic_tokens = generate_text_semantic(text, max_gen_duration_s=seconds, top_k=50, top_p=.95, temp=0.7) output_path = './modules/bark/bark/assets/prompts/' + voice.replace("/", "_") + '.npz' np.savez(output_path, fine_prompt=codes, coarse_prompt=codes[:2, :], semantic_prompt=semantic_tokens) def inference( self, text, voice, text_temp=0.7, waveform_temp=0.7 ): - if not os.path.exists('./modules/bark/bark/assets/prompts/' + voice + '.npz'): - self.create_voice( voice ) - voice = voice.replace("/", "_") - if voice not in ALLOWED_PROMPTS: - ALLOWED_PROMPTS.add( voice ) + if voice == "random": + voice = None + else: + if not os.path.exists('./modules/bark/bark/assets/prompts/' + voice + '.npz'): + self.create_voice( voice ) + voice = voice.replace("/", "_") + if voice not in ALLOWED_PROMPTS: + ALLOWED_PROMPTS.add( voice ) + + semantic_tokens = text_to_semantic(text, history_prompt=voice, temp=text_temp, silent=False) + audio_tokens = semantic_to_audio_tokens( semantic_tokens, history_prompt=voice, temp=waveform_temp, silent=False, output_full=False ) + + if VOCOS_ENABLED: + audio_tokens_torch = torch.from_numpy(audio_tokens).to(self.device) + features = self.vocos.codes_to_features(audio_tokens_torch) + wav = self.vocos.decode(features, bandwidth_id=torch.tensor([2], device=self.device)) + else: + wav = codec_decode( audio_tokens ) - return (bark_generate_audio(text, history_prompt=voice, text_temp=text_temp, waveform_temp=waveform_temp), BARK_SAMPLE_RATE) + return ( wav, BARK_SAMPLE_RATE ) + # return (bark_generate_audio(text, history_prompt=voice, text_temp=text_temp, waveform_temp=waveform_temp), BARK_SAMPLE_RATE) args = None tts = None @@ -371,8 +459,10 @@ def generate_bark(**kwargs): settings['datetime'] = datetime.now().isoformat() # save here in case some error happens mid-batch - #torchaudio.save(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav', wav.cpu(), sr) - write_wav(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav', sr, wav) + if VOCOS_ENABLED: + torchaudio.save(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav', wav.cpu(), sr) + else: + write_wav(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav', sr, wav) wav, sr = torchaudio.load(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav') audio_cache[name] = {