forked from mrq/ai-voice-cloning
updated bark support, it'll also query for vocos, it actually works (I don't know what specifically was the issue)
This commit is contained in:
parent
76ed34ddd2
commit
547e1d1277
128
src/utils.py
128
src/utils.py
|
@ -86,6 +86,7 @@ if VALLE_ENABLED:
|
||||||
TTSES.append('vall-e')
|
TTSES.append('vall-e')
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
from bark import text_to_semantic
|
||||||
from bark.generation import SAMPLE_RATE as BARK_SAMPLE_RATE, ALLOWED_PROMPTS, preload_models, codec_decode, generate_coarse, generate_fine, generate_text_semantic, load_codec_model
|
from bark.generation import SAMPLE_RATE as BARK_SAMPLE_RATE, ALLOWED_PROMPTS, preload_models, codec_decode, generate_coarse, generate_fine, generate_text_semantic, load_codec_model
|
||||||
from bark.api import generate_audio as bark_generate_audio
|
from bark.api import generate_audio as bark_generate_audio
|
||||||
from encodec.utils import convert_audio
|
from encodec.utils import convert_audio
|
||||||
|
@ -98,12 +99,48 @@ except Exception as e:
|
||||||
raise e
|
raise e
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
if BARK_ENABLED:
|
||||||
|
try:
|
||||||
|
from vocos import Vocos
|
||||||
|
VOCOS_ENABLED = True
|
||||||
|
except Exception as e:
|
||||||
|
VOCOS_ENABLED = False
|
||||||
|
|
||||||
|
try:
|
||||||
|
from hubert.hubert_manager import HuBERTManager
|
||||||
|
|
||||||
|
HUBERT_ENABLED = True
|
||||||
|
except Exception as e:
|
||||||
|
HUBERT_ENABLED = False
|
||||||
|
|
||||||
if BARK_ENABLED:
|
if BARK_ENABLED:
|
||||||
TTSES.append('bark')
|
TTSES.append('bark')
|
||||||
|
|
||||||
|
def semantic_to_audio_tokens(
|
||||||
|
semantic_tokens,
|
||||||
|
history_prompt = None,
|
||||||
|
temp = 0.7,
|
||||||
|
silent = False,
|
||||||
|
output_full = False,
|
||||||
|
):
|
||||||
|
coarse_tokens = generate_coarse(
|
||||||
|
semantic_tokens, history_prompt=history_prompt, temp=temp, silent=silent, use_kv_caching=True
|
||||||
|
)
|
||||||
|
fine_tokens = generate_fine(coarse_tokens, history_prompt=history_prompt, temp=0.5)
|
||||||
|
|
||||||
|
if output_full:
|
||||||
|
full_generation = {
|
||||||
|
"semantic_prompt": semantic_tokens,
|
||||||
|
"coarse_prompt": coarse_tokens,
|
||||||
|
"fine_prompt": fine_tokens,
|
||||||
|
}
|
||||||
|
return full_generation
|
||||||
|
return fine_tokens
|
||||||
|
|
||||||
class Bark_TTS():
|
class Bark_TTS():
|
||||||
def __init__(self, small=False):
|
def __init__(self, small=False):
|
||||||
self.input_sample_rate = BARK_SAMPLE_RATE
|
self.input_sample_rate = BARK_SAMPLE_RATE
|
||||||
self.output_sample_rate = args.output_sample_rate
|
self.output_sample_rate = BARK_SAMPLE_RATE # args.output_sample_rate
|
||||||
|
|
||||||
preload_models(
|
preload_models(
|
||||||
text_use_gpu=True,
|
text_use_gpu=True,
|
||||||
|
@ -118,7 +155,12 @@ if BARK_ENABLED:
|
||||||
force_reload=False
|
force_reload=False
|
||||||
)
|
)
|
||||||
|
|
||||||
def create_voice( self, voice, device='cuda' ):
|
self.device = get_device_name()
|
||||||
|
|
||||||
|
if VOCOS_ENABLED:
|
||||||
|
self.vocos = Vocos.from_pretrained("charactr/vocos-encodec-24khz").to(self.device)
|
||||||
|
|
||||||
|
def create_voice( self, voice ):
|
||||||
transcription_json = f'./training/{voice}/whisper.json'
|
transcription_json = f'./training/{voice}/whisper.json'
|
||||||
if not os.path.exists(transcription_json):
|
if not os.path.exists(transcription_json):
|
||||||
raise f"Transcription for voice not found: {voice}"
|
raise f"Transcription for voice not found: {voice}"
|
||||||
|
@ -146,29 +188,75 @@ if BARK_ENABLED:
|
||||||
model = load_codec_model(use_gpu=True)
|
model = load_codec_model(use_gpu=True)
|
||||||
wav, sr = torchaudio.load(audio_filepath)
|
wav, sr = torchaudio.load(audio_filepath)
|
||||||
wav = convert_audio(wav, sr, model.sample_rate, model.channels)
|
wav = convert_audio(wav, sr, model.sample_rate, model.channels)
|
||||||
wav = wav.unsqueeze(0).to(device)
|
|
||||||
|
|
||||||
# Extract discrete codes from EnCodec
|
|
||||||
with torch.no_grad():
|
|
||||||
encoded_frames = model.encode(wav)
|
|
||||||
codes = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1).squeeze().cpu().numpy() # [n_q, T]
|
|
||||||
|
|
||||||
# get seconds of audio
|
|
||||||
seconds = wav.shape[-1] / model.sample_rate
|
|
||||||
# generate semantic tokens
|
# generate semantic tokens
|
||||||
semantic_tokens = generate_text_semantic(text, max_gen_duration_s=seconds, top_k=50, top_p=.95, temp=0.7)
|
|
||||||
|
if HUBERT_ENABLED:
|
||||||
|
wav = wav.to(device)
|
||||||
|
|
||||||
|
# Extract discrete codes from EnCodec
|
||||||
|
with torch.no_grad():
|
||||||
|
encoded_frames = model.encode(wav.unsqueeze(0))
|
||||||
|
codes = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1).squeeze() # [n_q, T]
|
||||||
|
|
||||||
|
# get seconds of audio
|
||||||
|
seconds = wav.shape[-1] / model.sample_rate
|
||||||
|
|
||||||
|
hubert_manager = HuBERTManager()
|
||||||
|
hubert_manager.make_sure_hubert_installed()
|
||||||
|
hubert_manager.make_sure_tokenizer_installed()
|
||||||
|
|
||||||
|
from hubert.pre_kmeans_hubert import CustomHubert
|
||||||
|
from hubert.customtokenizer import CustomTokenizer
|
||||||
|
|
||||||
|
# Load the HuBERT model
|
||||||
|
hubert_model = CustomHubert(checkpoint_path='./models/hubert/hubert.pt').to(device)
|
||||||
|
|
||||||
|
# Load the CustomTokenizer model
|
||||||
|
tokenizer = CustomTokenizer.load_from_checkpoint('./models/hubert/tokenizer.pth').to(device)
|
||||||
|
else:
|
||||||
|
# Load and pre-process the audio waveform
|
||||||
|
model = load_codec_model(use_gpu=True)
|
||||||
|
wav, sr = torchaudio.load(audio_filepath)
|
||||||
|
wav = convert_audio(wav, sr, model.sample_rate, model.channels)
|
||||||
|
wav = wav.unsqueeze(0).to(self.device)
|
||||||
|
|
||||||
|
# Extract discrete codes from EnCodec
|
||||||
|
with torch.no_grad():
|
||||||
|
encoded_frames = model.encode(wav)
|
||||||
|
codes = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1).squeeze().cpu().numpy() # [n_q, T]
|
||||||
|
|
||||||
|
# get seconds of audio
|
||||||
|
seconds = wav.shape[-1] / model.sample_rate
|
||||||
|
|
||||||
|
# generate semantic tokens
|
||||||
|
semantic_tokens = generate_text_semantic(text, max_gen_duration_s=seconds, top_k=50, top_p=.95, temp=0.7)
|
||||||
|
|
||||||
output_path = './modules/bark/bark/assets/prompts/' + voice.replace("/", "_") + '.npz'
|
output_path = './modules/bark/bark/assets/prompts/' + voice.replace("/", "_") + '.npz'
|
||||||
np.savez(output_path, fine_prompt=codes, coarse_prompt=codes[:2, :], semantic_prompt=semantic_tokens)
|
np.savez(output_path, fine_prompt=codes, coarse_prompt=codes[:2, :], semantic_prompt=semantic_tokens)
|
||||||
|
|
||||||
def inference( self, text, voice, text_temp=0.7, waveform_temp=0.7 ):
|
def inference( self, text, voice, text_temp=0.7, waveform_temp=0.7 ):
|
||||||
if not os.path.exists('./modules/bark/bark/assets/prompts/' + voice + '.npz'):
|
if voice == "random":
|
||||||
self.create_voice( voice )
|
voice = None
|
||||||
voice = voice.replace("/", "_")
|
else:
|
||||||
if voice not in ALLOWED_PROMPTS:
|
if not os.path.exists('./modules/bark/bark/assets/prompts/' + voice + '.npz'):
|
||||||
ALLOWED_PROMPTS.add( voice )
|
self.create_voice( voice )
|
||||||
|
voice = voice.replace("/", "_")
|
||||||
|
if voice not in ALLOWED_PROMPTS:
|
||||||
|
ALLOWED_PROMPTS.add( voice )
|
||||||
|
|
||||||
return (bark_generate_audio(text, history_prompt=voice, text_temp=text_temp, waveform_temp=waveform_temp), BARK_SAMPLE_RATE)
|
semantic_tokens = text_to_semantic(text, history_prompt=voice, temp=text_temp, silent=False)
|
||||||
|
audio_tokens = semantic_to_audio_tokens( semantic_tokens, history_prompt=voice, temp=waveform_temp, silent=False, output_full=False )
|
||||||
|
|
||||||
|
if VOCOS_ENABLED:
|
||||||
|
audio_tokens_torch = torch.from_numpy(audio_tokens).to(self.device)
|
||||||
|
features = self.vocos.codes_to_features(audio_tokens_torch)
|
||||||
|
wav = self.vocos.decode(features, bandwidth_id=torch.tensor([2], device=self.device))
|
||||||
|
else:
|
||||||
|
wav = codec_decode( audio_tokens )
|
||||||
|
|
||||||
|
return ( wav, BARK_SAMPLE_RATE )
|
||||||
|
# return (bark_generate_audio(text, history_prompt=voice, text_temp=text_temp, waveform_temp=waveform_temp), BARK_SAMPLE_RATE)
|
||||||
|
|
||||||
args = None
|
args = None
|
||||||
tts = None
|
tts = None
|
||||||
|
@ -371,8 +459,10 @@ def generate_bark(**kwargs):
|
||||||
settings['datetime'] = datetime.now().isoformat()
|
settings['datetime'] = datetime.now().isoformat()
|
||||||
|
|
||||||
# save here in case some error happens mid-batch
|
# save here in case some error happens mid-batch
|
||||||
#torchaudio.save(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav', wav.cpu(), sr)
|
if VOCOS_ENABLED:
|
||||||
write_wav(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav', sr, wav)
|
torchaudio.save(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav', wav.cpu(), sr)
|
||||||
|
else:
|
||||||
|
write_wav(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav', sr, wav)
|
||||||
wav, sr = torchaudio.load(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav')
|
wav, sr = torchaudio.load(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav')
|
||||||
|
|
||||||
audio_cache[name] = {
|
audio_cache[name] = {
|
||||||
|
|
Loading…
Reference in New Issue
Block a user