Moved the (previously non-working) setting for using BigVGAN into a dropdown that selects between vocoders (so future ones can be slotted in), and added the ability to load a new vocoder while TTS is loaded

This commit is contained in:
mrq 2023-03-07 02:45:22 +00:00
parent e731b9ba84
commit 0f0b394445
3 changed files with 38 additions and 33 deletions

View File

@ -42,6 +42,8 @@ WHISPER_MODELS = ["tiny", "base", "small", "medium", "large", "large-v2"]
WHISPER_SPECIALIZED_MODELS = ["tiny.en", "base.en", "small.en", "medium.en"] WHISPER_SPECIALIZED_MODELS = ["tiny.en", "base.en", "small.en", "medium.en"]
WHISPER_BACKENDS = ["openai/whisper", "lightmare/whispercpp", "m-bain/whisperx"] WHISPER_BACKENDS = ["openai/whisper", "lightmare/whispercpp", "m-bain/whisperx"]
VOCODERS = ['univnet', 'bigvgan_base_24khz_100band'] #, 'bigvgan_24khz_100band']
EPOCH_SCHEDULE = [ 9, 18, 25, 33 ] EPOCH_SCHEDULE = [ 9, 18, 25, 33 ]
args = None args = None
@ -1539,7 +1541,7 @@ def setup_args():
'defer-tts-load': False, 'defer-tts-load': False,
'device-override': None, 'device-override': None,
'prune-nonfinal-outputs': True, 'prune-nonfinal-outputs': True,
'use-bigvgan-vocoder': True, 'vocoder-model': VOCODERS[-1],
'concurrency-count': 2, 'concurrency-count': 2,
'autocalculate-voice-chunk-duration-size': 0, 'autocalculate-voice-chunk-duration-size': 0,
'output-sample-rate': 44100, 'output-sample-rate': 44100,
@ -1576,7 +1578,7 @@ def setup_args():
parser.add_argument("--force-cpu-for-conditioning-latents", default=default_arguments['force-cpu-for-conditioning-latents'], action='store_true', help="Forces computing conditional latents to be done on the CPU (if you constantyl OOM on low chunk counts)") parser.add_argument("--force-cpu-for-conditioning-latents", default=default_arguments['force-cpu-for-conditioning-latents'], action='store_true', help="Forces computing conditional latents to be done on the CPU (if you constantyl OOM on low chunk counts)")
parser.add_argument("--defer-tts-load", default=default_arguments['defer-tts-load'], action='store_true', help="Defers loading TTS model") parser.add_argument("--defer-tts-load", default=default_arguments['defer-tts-load'], action='store_true', help="Defers loading TTS model")
parser.add_argument("--prune-nonfinal-outputs", default=default_arguments['prune-nonfinal-outputs'], action='store_true', help="Deletes non-final output files on completing a generation") parser.add_argument("--prune-nonfinal-outputs", default=default_arguments['prune-nonfinal-outputs'], action='store_true', help="Deletes non-final output files on completing a generation")
parser.add_argument("--use-bigvgan-vocoder", default=default_arguments['use-bigvgan-vocoder'], action='store_true', help="Uses BigVGAN in place of the default vocoder") parser.add_argument("--vocoder-model", default=default_arguments['vocoder-model'], action='store_true', help="Specifies with vocoder to use")
parser.add_argument("--device-override", default=default_arguments['device-override'], help="A device string to override pass through Torch") parser.add_argument("--device-override", default=default_arguments['device-override'], help="A device string to override pass through Torch")
parser.add_argument("--sample-batch-size", default=default_arguments['sample-batch-size'], type=int, help="Sets how many batches to use during the autoregressive samples pass") parser.add_argument("--sample-batch-size", default=default_arguments['sample-batch-size'], type=int, help="Sets how many batches to use during the autoregressive samples pass")
parser.add_argument("--concurrency-count", type=int, default=default_arguments['concurrency-count'], help="How many Gradio events to process at once") parser.add_argument("--concurrency-count", type=int, default=default_arguments['concurrency-count'], help="How many Gradio events to process at once")
@ -1620,7 +1622,7 @@ def setup_args():
return args return args
def update_args( listen, share, check_for_updates, models_from_local_only, low_vram, embed_output_metadata, latents_lean_and_mean, voice_fixer, voice_fixer_use_cuda, force_cpu_for_conditioning_latents, defer_tts_load, prune_nonfinal_outputs, use_bigvgan_vocoder, device_override, sample_batch_size, concurrency_count, autocalculate_voice_chunk_duration_size, output_volume, autoregressive_model, whisper_backend, whisper_model, training_default_halfp, training_default_bnb ): def update_args( listen, share, check_for_updates, models_from_local_only, low_vram, embed_output_metadata, latents_lean_and_mean, voice_fixer, voice_fixer_use_cuda, force_cpu_for_conditioning_latents, defer_tts_load, prune_nonfinal_outputs, device_override, sample_batch_size, concurrency_count, autocalculate_voice_chunk_duration_size, output_volume, autoregressive_model, vocoder_model, whisper_backend, whisper_model, training_default_halfp, training_default_bnb ):
global args global args
args.listen = listen args.listen = listen
@ -1631,7 +1633,6 @@ def update_args( listen, share, check_for_updates, models_from_local_only, low_v
args.force_cpu_for_conditioning_latents = force_cpu_for_conditioning_latents args.force_cpu_for_conditioning_latents = force_cpu_for_conditioning_latents
args.defer_tts_load = defer_tts_load args.defer_tts_load = defer_tts_load
args.prune_nonfinal_outputs = prune_nonfinal_outputs args.prune_nonfinal_outputs = prune_nonfinal_outputs
args.use_bigvgan_vocoder = use_bigvgan_vocoder
args.device_override = device_override args.device_override = device_override
args.sample_batch_size = sample_batch_size args.sample_batch_size = sample_batch_size
args.embed_output_metadata = embed_output_metadata args.embed_output_metadata = embed_output_metadata
@ -1644,6 +1645,7 @@ def update_args( listen, share, check_for_updates, models_from_local_only, low_v
args.output_volume = output_volume args.output_volume = output_volume
args.autoregressive_model = autoregressive_model args.autoregressive_model = autoregressive_model
args.vocoder_model = vocoder_model
args.whisper_backend = whisper_backend args.whisper_backend = whisper_backend
args.whisper_model = whisper_model args.whisper_model = whisper_model
@ -1663,7 +1665,6 @@ def save_args_settings():
'force-cpu-for-conditioning-latents': args.force_cpu_for_conditioning_latents, 'force-cpu-for-conditioning-latents': args.force_cpu_for_conditioning_latents,
'defer-tts-load': args.defer_tts_load, 'defer-tts-load': args.defer_tts_load,
'prune-nonfinal-outputs': args.prune_nonfinal_outputs, 'prune-nonfinal-outputs': args.prune_nonfinal_outputs,
'use-bigvgan-vocoder': args.use_bigvgan_vocoder,
'device-override': args.device_override, 'device-override': args.device_override,
'sample-batch-size': args.sample_batch_size, 'sample-batch-size': args.sample_batch_size,
'embed-output-metadata': args.embed_output_metadata, 'embed-output-metadata': args.embed_output_metadata,
@ -1676,6 +1677,7 @@ def save_args_settings():
'output-volume': args.output_volume, 'output-volume': args.output_volume,
'autoregressive-model': args.autoregressive_model, 'autoregressive-model': args.autoregressive_model,
'vocoder-model': args.vocoder_model,
'whisper-backend': args.whisper_backend, 'whisper-backend': args.whisper_backend,
'whisper-model': args.whisper_model, 'whisper-model': args.whisper_model,
@ -1791,11 +1793,11 @@ def load_tts( restart=False, model=None ):
if model: if model:
args.autoregressive_model = model args.autoregressive_model = model
print(f"Loading TorToiSe... (using model: {args.autoregressive_model})") print(f"Loading TorToiSe... (AR: {args.autoregressive_model}, vocoder: {args.vocoder_model})")
tts_loading = True tts_loading = True
try: try:
tts = TextToSpeech(minor_optimizations=not args.low_vram, autoregressive_model_path=args.autoregressive_model) tts = TextToSpeech(minor_optimizations=not args.low_vram, autoregressive_model_path=args.autoregressive_model, vocoder_model=args.vocoder_model)
except Exception as e: except Exception as e:
tts = TextToSpeech(minor_optimizations=not args.low_vram) tts = TextToSpeech(minor_optimizations=not args.low_vram)
load_autoregressive_model(args.autoregressive_model) load_autoregressive_model(args.autoregressive_model)
@ -1843,35 +1845,32 @@ def update_autoregressive_model(autoregressive_model_path):
return return
print(f"Loading model: {autoregressive_model_path}") print(f"Loading model: {autoregressive_model_path}")
if hasattr(tts, 'load_autoregressive_model') and tts.load_autoregressive_model(autoregressive_model_path):
tts.load_autoregressive_model(autoregressive_model_path) tts.load_autoregressive_model(autoregressive_model_path)
# polyfill in case a user did NOT update the packages
# this shouldn't happen anymore, as I just clone mrq/tortoise-tts, and inject it into sys.path
else:
from tortoise.models.autoregressive import UnifiedVoice
tts.autoregressive_model_path = autoregressive_model_path if autoregressive_model_path and os.path.exists(autoregressive_model_path) else get_model_path('autoregressive.pth', tts.models_dir)
del tts.autoregressive
tts.autoregressive = UnifiedVoice(max_mel_tokens=604, max_text_tokens=402, max_conditioning_inputs=2, layers=30,
model_dim=1024,
heads=16, number_text_tokens=255, start_text_token=255, checkpointing=False,
train_solo_embeddings=False).cpu().eval()
tts.autoregressive.load_state_dict(torch.load(tts.autoregressive_model_path))
tts.autoregressive.post_init_gpt2_config(kv_cache=tts.use_kv_cache)
if tts.preloaded_tensors:
tts.autoregressive = tts.autoregressive.to(tts.device)
if not hasattr(tts, 'autoregressive_model_hash'):
tts.autoregressive_model_hash = hash_file(autoregressive_model_path)
print(f"Loaded model: {tts.autoregressive_model_path}") print(f"Loaded model: {tts.autoregressive_model_path}")
do_gc() do_gc()
return autoregressive_model_path return autoregressive_model_path
def update_vocoder_model(vocoder_model):
    """Store the selected vocoder in the settings and apply it to a loaded TTS.

    The selection is written to ``args.vocoder_model`` and persisted through
    ``save_args_settings()`` before anything else, so it is kept even when no
    TTS instance is available to receive it.

    Args:
        vocoder_model: Identifier of the vocoder to switch to; it is passed
            unchanged to ``tts.load_vocoder_model``.

    Returns:
        ``vocoder_model`` once the loaded TTS instance has been asked to load
        it, or ``None`` when there is no TTS instance and none is loading.

    Raises:
        Exception: If there is no TTS instance yet but ``tts_loading`` is set.
    """
    global tts

    # Persist first: the stored setting must survive even if nothing is loaded.
    args.vocoder_model = vocoder_model
    save_args_settings()
    print(f'Stored vocoder model to settings: {vocoder_model}')

    if tts:
        print(f"Loading model: {vocoder_model}")
        tts.load_vocoder_model(vocoder_model)
        # Report what the TTS object itself says it now holds.
        print(f"Loaded model: {tts.vocoder_model}")
        do_gc()
        return vocoder_model

    # No TTS instance: refuse while one is mid-initialisation, otherwise no-op.
    if tts_loading:
        raise Exception("TTS is still initializing...")
    return None
def load_voicefixer(restart=False): def load_voicefixer(restart=False):
global voicefixer global voicefixer

View File

@ -577,7 +577,6 @@ def setup_gradio():
gr.Checkbox(label="Force CPU for Conditioning Latents", value=args.force_cpu_for_conditioning_latents), gr.Checkbox(label="Force CPU for Conditioning Latents", value=args.force_cpu_for_conditioning_latents),
gr.Checkbox(label="Do Not Load TTS On Startup", value=args.defer_tts_load), gr.Checkbox(label="Do Not Load TTS On Startup", value=args.defer_tts_load),
gr.Checkbox(label="Delete Non-Final Output", value=args.prune_nonfinal_outputs), gr.Checkbox(label="Delete Non-Final Output", value=args.prune_nonfinal_outputs),
gr.Checkbox(label="Use BigVGAN Vocoder", value=args.use_bigvgan_vocoder),
gr.Textbox(label="Device Override", value=args.device_override), gr.Textbox(label="Device Override", value=args.device_override),
] ]
with gr.Column(): with gr.Column():
@ -590,10 +589,11 @@ def setup_gradio():
autoregressive_model_dropdown = gr.Dropdown(choices=autoregressive_models, label="Autoregressive Model", value=args.autoregressive_model if args.autoregressive_model else autoregressive_models[0]) autoregressive_model_dropdown = gr.Dropdown(choices=autoregressive_models, label="Autoregressive Model", value=args.autoregressive_model if args.autoregressive_model else autoregressive_models[0])
vocoder_models = gr.Dropdown(VOCODERS, label="Vocoder", value=args.vocoder_model if args.vocoder_model else VOCODERS[-1])
whisper_backend = gr.Dropdown(WHISPER_BACKENDS, label="Whisper Backends", value=args.whisper_backend) whisper_backend = gr.Dropdown(WHISPER_BACKENDS, label="Whisper Backends", value=args.whisper_backend)
whisper_model_dropdown = gr.Dropdown(WHISPER_MODELS, label="Whisper Model", value=args.whisper_model) whisper_model_dropdown = gr.Dropdown(WHISPER_MODELS, label="Whisper Model", value=args.whisper_model)
exec_inputs = exec_inputs + [ autoregressive_model_dropdown, whisper_backend, whisper_model_dropdown, training_halfp, training_bnb ] exec_inputs = exec_inputs + [ autoregressive_model_dropdown, vocoder_models, whisper_backend, whisper_model_dropdown, training_halfp, training_bnb ]
with gr.Row(): with gr.Row():
autoregressive_models_update_button = gr.Button(value="Refresh Model List") autoregressive_models_update_button = gr.Button(value="Refresh Model List")
@ -626,6 +626,12 @@ def setup_gradio():
outputs=None outputs=None
) )
vocoder_models.change(
fn=update_vocoder_model,
inputs=vocoder_models,
outputs=None
)
input_settings = [ input_settings = [
text, text,
delimiter, delimiter,

@ -1 +1 @@
Subproject commit 6fcd8c604f066e4e346da522bd14e6670395025f Subproject commit e2db36af602297501132f7f68331755f5904825a