@ -31,7 +31,7 @@ import pandas as pd
from datetime import datetime
from datetime import timedelta
from tortoise . api import TextToSpeech , MODELS , get_model_path
from tortoise . api import TextToSpeech , MODELS , get_model_path , pad_or_truncate
from tortoise . utils . audio import load_audio , load_voice , load_voices , get_voice_dir
from tortoise . utils . text import split_and_recombine_text
from tortoise . utils . device import get_device_name , set_device_name
@ -89,6 +89,8 @@ def generate(
if tts_loading :
raise Exception ( " TTS is still initializing... " )
load_tts ( )
if hasattr ( tts , " loading " ) and tts . loading :
raise Exception ( " TTS is still initializing... " )
do_gc ( )
@ -121,17 +123,8 @@ def generate(
voice_samples , conditioning_latents = load_voice ( voice )
if voice_samples and len ( voice_samples ) > 0 :
conditioning_latents = compute_latents ( voice = voice , voice_samples = voice_samples , voice_latents_chunks = voice_latents_chunks )
sample_voice = torch . cat ( voice_samples , dim = - 1 ) . squeeze ( ) . cpu ( )
conditioning_latents = tts . get_conditioning_latents ( voice_samples , return_mels = not args . latents_lean_and_mean , progress = progress , slices = voice_latents_chunks , force_cpu = args . force_cpu_for_conditioning_latents )
if len ( conditioning_latents ) == 4 :
conditioning_latents = ( conditioning_latents [ 0 ] , conditioning_latents [ 1 ] , conditioning_latents [ 2 ] , None )
if voice != " microphone " :
if hasattr ( tts , ' autoregressive_model_hash ' ) :
torch . save ( conditioning_latents , f ' { get_voice_dir ( ) } / { voice } /cond_latents_ { tts . autoregressive_model_hash [ : 8 ] } .pth ' )
else :
torch . save ( conditioning_latents , f ' { get_voice_dir ( ) } / { voice } /cond_latents.pth ' )
voice_samples = None
else :
if conditioning_latents is not None :
@ -551,6 +544,10 @@ def update_baseline_for_latents_chunks( voice ):
if not os . path . isdir ( path ) :
return 1
dataset_file = f ' ./training/ { voice } /train.txt '
if os . path . exists ( dataset_file ) :
return 0 # 0 will leverage using the LJspeech dataset for computing latents
files = os . listdir ( path )
total = 0
@ -565,11 +562,13 @@ def update_baseline_for_latents_chunks( voice ):
total_duration + = duration
total = total + 1
# brain too fried to figure out a better way
if args . autocalculate_voice_chunk_duration_size == 0 :
return int ( total_duration / total ) if total > 0 else 1
return int ( total_duration / args . autocalculate_voice_chunk_duration_size ) if total_duration > 0 else 1
def compute_latents ( voice , voice_latents_chunks , progress = gr . Progress ( track_tqdm = True ) ) :
def compute_latents ( voice = None , voice_samples = None , voice_latents_chunks = 0 , progress = None ) :
global tts
global args
@ -581,12 +580,42 @@ def compute_latents(voice, voice_latents_chunks, progress=gr.Progress(track_tqdm
raise Exception ( " TTS is still initializing... " )
load_tts ( )
voice_samples , conditioning_latents = load_voice ( voice , load_latents = False )
if hasattr ( tts , " loading " ) and tts . loading :
raise Exception ( " TTS is still initializing... " )
if voice :
load_from_dataset = voice_latents_chunks == 0
if load_from_dataset :
dataset_path = f ' ./training/ { voice } /train.txt '
if not os . path . exists ( dataset_path ) :
load_from_dataset = False
else :
with open ( dataset_path , ' r ' , encoding = " utf-8 " ) as f :
lines = f . readlines ( )
print ( " Leveraging LJSpeech dataset for computing latents " )
voice_samples = [ ]
max_length = 0
for line in lines :
filename = f ' ./training/ { voice } / { line . split ( " | " ) [ 0 ] } '
waveform = load_audio ( filename , 22050 )
max_length = max ( max_length , waveform . shape [ - 1 ] )
voice_samples . append ( waveform )
for i in range ( len ( voice_samples ) ) :
voice_samples [ i ] = pad_or_truncate ( voice_samples [ i ] , max_length )
voice_latents_chunks = len ( voice_samples )
if not load_from_dataset :
voice_samples , _ = load_voice ( voice , load_latents = False )
if voice_samples is None :
return
conditioning_latents = tts . get_conditioning_latents ( voice_samples , return_mels = not args . latents_lean_and_mean , progress = progress , slices = voice_latents_chunks , force_cpu = args . force_cpu_for_conditioning_latents )
conditioning_latents = tts . get_conditioning_latents ( voice_samples , return_mels = not args . latents_lean_and_mean , slices= voice_latents_chunks , force_cpu = args . force_cpu_for_conditioning_latent s, progress = progres s)
if len ( conditioning_latents ) == 4 :
conditioning_latents = ( conditioning_latents [ 0 ] , conditioning_latents [ 1 ] , conditioning_latents [ 2 ] , None )
@ -596,7 +625,7 @@ def compute_latents(voice, voice_latents_chunks, progress=gr.Progress(track_tqdm
else :
torch . save ( conditioning_latents , f ' { get_voice_dir ( ) } / { voice } /cond_latents.pth ' )
return voice
return conditioning_latents
# superfluous, but it cleans up some things
class TrainingState ( ) :
@ -1847,6 +1876,10 @@ def update_autoregressive_model(autoregressive_model_path):
if tts_loading :
raise Exception ( " TTS is still initializing... " )
return
if hasattr ( tts , " loading " ) and tts . loading :
raise Exception ( " TTS is still initializing... " )
print ( f " Loading model: { autoregressive_model_path } " )
tts . load_autoregressive_model ( autoregressive_model_path )
@ -1867,6 +1900,9 @@ def update_vocoder_model(vocoder_model):
raise Exception ( " TTS is still initializing... " )
return
if hasattr ( tts , " loading " ) and tts . loading :
raise Exception ( " TTS is still initializing... " )
print ( f " Loading model: { vocoder_model } " )
tts . load_vocoder_model ( vocoder_model )
print ( f " Loaded model: { tts . vocoder_model } " )