@ -62,6 +62,7 @@ def setup_args():
' defer-tts-load ' : False ,
' device-override ' : None ,
' whisper-model ' : " base " ,
' autoregressive-model ' : None ,
' concurrency-count ' : 2 ,
' output-sample-rate ' : 44100 ,
' output-volume ' : 1 ,
@ -87,6 +88,7 @@ def setup_args():
parser . add_argument ( " --defer-tts-load " , default = default_arguments [ ' defer-tts-load ' ] , action = ' store_true ' , help = " Defers loading TTS model " )
parser . add_argument ( " --device-override " , default = default_arguments [ ' device-override ' ] , help = " A device string to override pass through Torch " )
parser . add_argument ( " --whisper-model " , default = default_arguments [ ' whisper-model ' ] , help = " Specifies which whisper model to use for transcription. " )
parser . add_argument ( " --autoregressive-model " , default = default_arguments [ ' autoregressive-model ' ] , help = " Specifies which autoregressive model to use for sampling. " )
parser . add_argument ( " --sample-batch-size " , default = default_arguments [ ' sample-batch-size ' ] , type = int , help = " Sets how many batches to use during the autoregressive samples pass " )
parser . add_argument ( " --concurrency-count " , type = int , default = default_arguments [ ' concurrency-count ' ] , help = " How many Gradio events to process at once " )
parser . add_argument ( " --output-sample-rate " , type = int , default = default_arguments [ ' output-sample-rate ' ] , help = " Sample rate to resample the output to (from 24KHz) " )
@ -151,10 +153,8 @@ def generate(
global args
global tts
try :
tts
except NameError :
raise Exception ( " TTS is still initializing... " )
if not tts :
raise Exception ( " TTS is uninitialized or still initializing... " )
if voice != " microphone " :
voices = [ voice ]
@ -493,7 +493,7 @@ def setup_tortoise(restart=False):
tts = None
print ( " Initializating TorToiSe... " )
tts = TextToSpeech ( minor_optimizations = not args . low_vram )
tts = TextToSpeech ( minor_optimizations = not args . low_vram , autoregressive_model_path = args . autoregressive_model )
get_model_path ( ' dvae.pth ' )
print ( " TorToiSe initialized, ready for generation. " )
return tts
@ -720,7 +720,47 @@ def get_voice_list(dir=get_voice_dir()):
os . makedirs ( dir , exist_ok = True )
return sorted ( [ d for d in os . listdir ( dir ) if os . path . isdir ( os . path . join ( dir , d ) ) and len ( os . listdir ( os . path . join ( dir , d ) ) ) > 0 ] ) + [ " microphone " , " random " ]
def export_exec_settings ( listen , share , check_for_updates , models_from_local_only , low_vram , embed_output_metadata , latents_lean_and_mean , voice_fixer , voice_fixer_use_cuda , force_cpu_for_conditioning_latents , defer_tts_load , device_override , sample_batch_size , concurrency_count , output_sample_rate , output_volume , whisper_model ) :
def get_autoregressive_models(dir=" ./models/finetuned/ "):
    """List the autoregressive models that can be selected.

    Creates `dir` if it does not exist yet, then returns a list whose first
    entry is whatever `get_model_path` returns for the stock autoregressive
    checkpoint, followed by the sorted names of every non-empty
    sub-directory found directly inside `dir`.

    Note that the entries after the first are bare directory names, not
    paths joined with `dir`.
    """
    # NOTE(review): string literals are reproduced exactly as they appear in
    # the reviewed source, surrounding spaces included -- confirm against the
    # unmangled file.
    os.makedirs(dir, exist_ok=True)

    # Resolve the stock model first, matching the original evaluation order.
    stock_model = get_model_path(' autoregressive.pth ')

    finetuned = []
    for entry in os.listdir(dir):
        candidate = os.path.join(dir, entry)
        if not os.path.isdir(candidate):
            continue
        # Skip directories that hold no files at all.
        if len(os.listdir(candidate)) == 0:
            continue
        finetuned.append(entry)
    finetuned.sort()

    return [stock_model] + finetuned
# Swap the autoregressive model used by the global `tts` object for the one at
# `path_name`, record the choice in `args.autoregressive_model`, persist it via
# `save_args_settings()`, and return `path_name`.
#
# Raises a plain `Exception` when `tts` is falsy (not created yet).
#
# NOTE(review): indentation is missing in this view of the file, so the exact
# nesting of the statements below is inferred, not shown -- confirm against the
# properly indented source.
def update_autoregressive_model ( path_name ) :
global tts
# Refuse to do anything until a TTS instance exists.
if not tts :
raise Exception ( " TTS is uninitialized or still initializing... " )
print ( f " Loading model: { path_name } " )
# Preferred path: let the TTS object load the model itself when it offers
# `load_autoregressive_model` and that call returns a truthy value.
if hasattr ( tts , ' load_autoregressive_model ' ) and tts . load_autoregressive_model ( path_name ) :
args . autoregressive_model = path_name
save_args_settings ( )
# Polyfill in case a user did NOT update the packages: taken when the method
# is absent, and also when it exists but returns a falsy value.
else :
# Imported here so the class is only needed on the fallback path.
from tortoise . models . autoregressive import UnifiedVoice
# Remember the old path so a real change can be detected further down.
previous_path = tts . autoregressive_model_path
# Use `path_name` only if it is non-empty and exists on disk; otherwise fall
# back to whatever `get_model_path` returns for the stock checkpoint.
tts . autoregressive_model_path = path_name if path_name and os . path . exists ( path_name ) else get_model_path ( ' autoregressive.pth ' )
# Drop the reference to the current model before building its replacement.
del tts . autoregressive
# Rebuild the model on CPU in eval mode with hard-coded hyperparameters
# (presumably those of the stock checkpoint -- TODO confirm), then load the
# weights from the selected file.
tts . autoregressive = UnifiedVoice ( max_mel_tokens = 604 , max_text_tokens = 402 , max_conditioning_inputs = 2 , layers = 30 ,
model_dim = 1024 ,
heads = 16 , number_text_tokens = 255 , start_text_token = 255 , checkpointing = False ,
train_solo_embeddings = False ) . cpu ( ) . eval ( )
tts . autoregressive . load_state_dict ( torch . load ( tts . autoregressive_model_path ) )
tts . autoregressive . post_init_gpt2_config ( kv_cache = tts . use_kv_cache )
# Move the fresh model to the TTS device only when `tts.preloaded_tensors` is
# truthy; otherwise it stays on CPU.
if tts . preloaded_tensors :
tts . autoregressive = tts . autoregressive . to ( tts . device )
# Persist the setting only when the effective model path actually changed.
# `previous_path` is assigned only on the fallback path, so this check
# presumably belongs to the `else` branch -- confirm.
# Note that `path_name` is stored as given, even if the stock checkpoint was
# substituted for it.
if previous_path != tts . autoregressive_model_path :
args . autoregressive_model = path_name
save_args_settings ( )
print ( f " Loaded model: { tts . autoregressive_model_path } " )
# The requested value is returned unchanged, not the path actually loaded.
return path_name
def update_args ( listen , share , check_for_updates , models_from_local_only , low_vram , embed_output_metadata , latents_lean_and_mean , voice_fixer , voice_fixer_use_cuda , force_cpu_for_conditioning_latents , defer_tts_load , device_override , sample_batch_size , concurrency_count , output_sample_rate , output_volume ) :
global args
args . listen = listen
@ -731,7 +771,6 @@ def export_exec_settings( listen, share, check_for_updates, models_from_local_on
args . force_cpu_for_conditioning_latents = force_cpu_for_conditioning_latents
args . defer_tts_load = defer_tts_load
args . device_override = device_override
args . whisper_model = whisper_model
args . sample_batch_size = sample_batch_size
args . embed_output_metadata = embed_output_metadata
args . latents_lean_and_mean = latents_lean_and_mean
@ -741,6 +780,9 @@ def export_exec_settings( listen, share, check_for_updates, models_from_local_on
args . output_sample_rate = output_sample_rate
args . output_volume = output_volume
save_args_settings ( )
def save_args_settings ( ) :
settings = {
' listen ' : None if args . listen else args . listen ,
' share ' : args . share ,
@ -751,6 +793,7 @@ def export_exec_settings( listen, share, check_for_updates, models_from_local_on
' defer-tts-load ' : args . defer_tts_load ,
' device-override ' : args . device_override ,
' whisper-model ' : args . whisper_model ,
' autoregressive-model ' : args . autoregressive_model ,
' sample-batch-size ' : args . sample_batch_size ,
' embed-output-metadata ' : args . embed_output_metadata ,
' latents-lean-and-mean ' : args . latents_lean_and_mean ,