forked from mrq/ai-voice-cloning
added (yet another) experimental voice latent calculation mode (when chunk size is 0 and theres a dataset generated, itll leverage it by padding to a common size then computing them, should help avoid splitting mid-phoneme)
This commit is contained in:
parent
5063728bb0
commit
3899f9b4e3
66
src/utils.py
66
src/utils.py
|
@ -31,7 +31,7 @@ import pandas as pd
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
|
|
||||||
from tortoise.api import TextToSpeech, MODELS, get_model_path
|
from tortoise.api import TextToSpeech, MODELS, get_model_path, pad_or_truncate
|
||||||
from tortoise.utils.audio import load_audio, load_voice, load_voices, get_voice_dir
|
from tortoise.utils.audio import load_audio, load_voice, load_voices, get_voice_dir
|
||||||
from tortoise.utils.text import split_and_recombine_text
|
from tortoise.utils.text import split_and_recombine_text
|
||||||
from tortoise.utils.device import get_device_name, set_device_name
|
from tortoise.utils.device import get_device_name, set_device_name
|
||||||
|
@ -89,6 +89,8 @@ def generate(
|
||||||
if tts_loading:
|
if tts_loading:
|
||||||
raise Exception("TTS is still initializing...")
|
raise Exception("TTS is still initializing...")
|
||||||
load_tts()
|
load_tts()
|
||||||
|
if hasattr(tts, "loading") and tts.loading:
|
||||||
|
raise Exception("TTS is still initializing...")
|
||||||
|
|
||||||
do_gc()
|
do_gc()
|
||||||
|
|
||||||
|
@ -121,17 +123,8 @@ def generate(
|
||||||
voice_samples, conditioning_latents = load_voice(voice)
|
voice_samples, conditioning_latents = load_voice(voice)
|
||||||
|
|
||||||
if voice_samples and len(voice_samples) > 0:
|
if voice_samples and len(voice_samples) > 0:
|
||||||
|
conditioning_latents = compute_latents(voice=voice, voice_samples=voice_samples, voice_latents_chunks=voice_latents_chunks)
|
||||||
sample_voice = torch.cat(voice_samples, dim=-1).squeeze().cpu()
|
sample_voice = torch.cat(voice_samples, dim=-1).squeeze().cpu()
|
||||||
|
|
||||||
conditioning_latents = tts.get_conditioning_latents(voice_samples, return_mels=not args.latents_lean_and_mean, progress=progress, slices=voice_latents_chunks, force_cpu=args.force_cpu_for_conditioning_latents)
|
|
||||||
if len(conditioning_latents) == 4:
|
|
||||||
conditioning_latents = (conditioning_latents[0], conditioning_latents[1], conditioning_latents[2], None)
|
|
||||||
|
|
||||||
if voice != "microphone":
|
|
||||||
if hasattr(tts, 'autoregressive_model_hash'):
|
|
||||||
torch.save(conditioning_latents, f'{get_voice_dir()}/{voice}/cond_latents_{tts.autoregressive_model_hash[:8]}.pth')
|
|
||||||
else:
|
|
||||||
torch.save(conditioning_latents, f'{get_voice_dir()}/{voice}/cond_latents.pth')
|
|
||||||
voice_samples = None
|
voice_samples = None
|
||||||
else:
|
else:
|
||||||
if conditioning_latents is not None:
|
if conditioning_latents is not None:
|
||||||
|
@ -551,6 +544,10 @@ def update_baseline_for_latents_chunks( voice ):
|
||||||
if not os.path.isdir(path):
|
if not os.path.isdir(path):
|
||||||
return 1
|
return 1
|
||||||
|
|
||||||
|
dataset_file = f'./training/{voice}/train.txt'
|
||||||
|
if os.path.exists(dataset_file):
|
||||||
|
return 0 # 0 will leverage using the LJspeech dataset for computing latents
|
||||||
|
|
||||||
files = os.listdir(path)
|
files = os.listdir(path)
|
||||||
|
|
||||||
total = 0
|
total = 0
|
||||||
|
@ -565,11 +562,13 @@ def update_baseline_for_latents_chunks( voice ):
|
||||||
total_duration += duration
|
total_duration += duration
|
||||||
total = total + 1
|
total = total + 1
|
||||||
|
|
||||||
|
|
||||||
|
# brain too fried to figure out a better way
|
||||||
if args.autocalculate_voice_chunk_duration_size == 0:
|
if args.autocalculate_voice_chunk_duration_size == 0:
|
||||||
return int(total_duration / total) if total > 0 else 1
|
return int(total_duration / total) if total > 0 else 1
|
||||||
return int(total_duration / args.autocalculate_voice_chunk_duration_size) if total_duration > 0 else 1
|
return int(total_duration / args.autocalculate_voice_chunk_duration_size) if total_duration > 0 else 1
|
||||||
|
|
||||||
def compute_latents(voice, voice_latents_chunks, progress=gr.Progress(track_tqdm=True)):
|
def compute_latents(voice=None, voice_samples=None, voice_latents_chunks=0, progress=None):
|
||||||
global tts
|
global tts
|
||||||
global args
|
global args
|
||||||
|
|
||||||
|
@ -581,12 +580,42 @@ def compute_latents(voice, voice_latents_chunks, progress=gr.Progress(track_tqdm
|
||||||
raise Exception("TTS is still initializing...")
|
raise Exception("TTS is still initializing...")
|
||||||
load_tts()
|
load_tts()
|
||||||
|
|
||||||
voice_samples, conditioning_latents = load_voice(voice, load_latents=False)
|
if hasattr(tts, "loading") and tts.loading:
|
||||||
|
raise Exception("TTS is still initializing...")
|
||||||
|
|
||||||
|
if voice:
|
||||||
|
load_from_dataset = voice_latents_chunks == 0
|
||||||
|
|
||||||
|
if load_from_dataset:
|
||||||
|
dataset_path = f'./training/{voice}/train.txt'
|
||||||
|
if not os.path.exists(dataset_path):
|
||||||
|
load_from_dataset = False
|
||||||
|
else:
|
||||||
|
with open(dataset_path, 'r', encoding="utf-8") as f:
|
||||||
|
lines = f.readlines()
|
||||||
|
|
||||||
|
print("Leveraging LJSpeech dataset for computing latents")
|
||||||
|
|
||||||
|
voice_samples = []
|
||||||
|
max_length = 0
|
||||||
|
for line in lines:
|
||||||
|
filename = f'./training/{voice}/{line.split("|")[0]}'
|
||||||
|
|
||||||
|
waveform = load_audio(filename, 22050)
|
||||||
|
max_length = max(max_length, waveform.shape[-1])
|
||||||
|
voice_samples.append(waveform)
|
||||||
|
|
||||||
|
for i in range(len(voice_samples)):
|
||||||
|
voice_samples[i] = pad_or_truncate(voice_samples[i], max_length)
|
||||||
|
|
||||||
|
voice_latents_chunks = len(voice_samples)
|
||||||
|
if not load_from_dataset:
|
||||||
|
voice_samples, _ = load_voice(voice, load_latents=False)
|
||||||
|
|
||||||
if voice_samples is None:
|
if voice_samples is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
conditioning_latents = tts.get_conditioning_latents(voice_samples, return_mels=not args.latents_lean_and_mean, progress=progress, slices=voice_latents_chunks, force_cpu=args.force_cpu_for_conditioning_latents)
|
conditioning_latents = tts.get_conditioning_latents(voice_samples, return_mels=not args.latents_lean_and_mean, slices=voice_latents_chunks, force_cpu=args.force_cpu_for_conditioning_latents, progress=progress)
|
||||||
|
|
||||||
if len(conditioning_latents) == 4:
|
if len(conditioning_latents) == 4:
|
||||||
conditioning_latents = (conditioning_latents[0], conditioning_latents[1], conditioning_latents[2], None)
|
conditioning_latents = (conditioning_latents[0], conditioning_latents[1], conditioning_latents[2], None)
|
||||||
|
@ -596,7 +625,7 @@ def compute_latents(voice, voice_latents_chunks, progress=gr.Progress(track_tqdm
|
||||||
else:
|
else:
|
||||||
torch.save(conditioning_latents, f'{get_voice_dir()}/{voice}/cond_latents.pth')
|
torch.save(conditioning_latents, f'{get_voice_dir()}/{voice}/cond_latents.pth')
|
||||||
|
|
||||||
return voice
|
return conditioning_latents
|
||||||
|
|
||||||
# superfluous, but it cleans up some things
|
# superfluous, but it cleans up some things
|
||||||
class TrainingState():
|
class TrainingState():
|
||||||
|
@ -1847,6 +1876,10 @@ def update_autoregressive_model(autoregressive_model_path):
|
||||||
if tts_loading:
|
if tts_loading:
|
||||||
raise Exception("TTS is still initializing...")
|
raise Exception("TTS is still initializing...")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
if hasattr(tts, "loading") and tts.loading:
|
||||||
|
raise Exception("TTS is still initializing...")
|
||||||
|
|
||||||
|
|
||||||
print(f"Loading model: {autoregressive_model_path}")
|
print(f"Loading model: {autoregressive_model_path}")
|
||||||
tts.load_autoregressive_model(autoregressive_model_path)
|
tts.load_autoregressive_model(autoregressive_model_path)
|
||||||
|
@ -1867,6 +1900,9 @@ def update_vocoder_model(vocoder_model):
|
||||||
raise Exception("TTS is still initializing...")
|
raise Exception("TTS is still initializing...")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
if hasattr(tts, "loading") and tts.loading:
|
||||||
|
raise Exception("TTS is still initializing...")
|
||||||
|
|
||||||
print(f"Loading model: {vocoder_model}")
|
print(f"Loading model: {vocoder_model}")
|
||||||
tts.load_vocoder_model(vocoder_model)
|
tts.load_vocoder_model(vocoder_model)
|
||||||
print(f"Loaded model: {tts.vocoder_model}")
|
print(f"Loaded model: {tts.vocoder_model}")
|
||||||
|
|
|
@ -163,6 +163,11 @@ def history_view_results( voice ):
|
||||||
gr.Dropdown.update(choices=sorted(files))
|
gr.Dropdown.update(choices=sorted(files))
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def compute_latents_proxy(voice, voice_latents_chunks, progress=gr.Progress(track_tqdm=True)):
|
||||||
|
compute_latents( voice=voice, voice_latents_chunks=voice_latents_chunks, progress=progress )
|
||||||
|
return voice
|
||||||
|
|
||||||
|
|
||||||
def import_voices_proxy(files, name, progress=gr.Progress(track_tqdm=True)):
|
def import_voices_proxy(files, name, progress=gr.Progress(track_tqdm=True)):
|
||||||
import_voices(files, name, progress)
|
import_voices(files, name, progress)
|
||||||
return gr.update()
|
return gr.update()
|
||||||
|
@ -387,7 +392,7 @@ def setup_gradio():
|
||||||
prompt = gr.Textbox(lines=1, label="Custom Emotion")
|
prompt = gr.Textbox(lines=1, label="Custom Emotion")
|
||||||
voice = gr.Dropdown(choices=voice_list_with_defaults, label="Voice", type="value", value=voice_list_with_defaults[0]) # it'd be very cash money if gradio was able to default to the first value in the list without this shit
|
voice = gr.Dropdown(choices=voice_list_with_defaults, label="Voice", type="value", value=voice_list_with_defaults[0]) # it'd be very cash money if gradio was able to default to the first value in the list without this shit
|
||||||
mic_audio = gr.Audio( label="Microphone Source", source="microphone", type="filepath", visible=False )
|
mic_audio = gr.Audio( label="Microphone Source", source="microphone", type="filepath", visible=False )
|
||||||
voice_latents_chunks = gr.Slider(label="Voice Chunks", minimum=1, maximum=128, value=1, step=1)
|
voice_latents_chunks = gr.Number(label="Voice Chunks", precision=0, value=0)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
refresh_voices = gr.Button(value="Refresh Voice List")
|
refresh_voices = gr.Button(value="Refresh Voice List")
|
||||||
recompute_voice_latents = gr.Button(value="(Re)Compute Voice Latents")
|
recompute_voice_latents = gr.Button(value="(Re)Compute Voice Latents")
|
||||||
|
@ -704,7 +709,7 @@ def setup_gradio():
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
recompute_voice_latents.click(compute_latents,
|
recompute_voice_latents.click(compute_latents_proxy,
|
||||||
inputs=[
|
inputs=[
|
||||||
voice,
|
voice,
|
||||||
voice_latents_chunks,
|
voice_latents_chunks,
|
||||||
|
|
Loading…
Reference in New Issue
Block a user