added VALL-E inference support (very rudimentary, gimped, but it will load a model trained on a config generated through the web UI)

This commit is contained in:
mrq 2023-03-31 03:26:00 +00:00
parent 9b01377667
commit 4744120be2
3 changed files with 426 additions and 116 deletions

View File

@ -1,61 +0,0 @@
{
"optimizer": {
"type": "AdamW",
"params": {
"lr": 2e-05,
"betas": [
0.9,
0.96
],
"eps": 1e-07,
"weight_decay": 0.01
}
},
"scheduler":{
"type":"WarmupLR",
"params":{
"warmup_min_lr":0,
"warmup_max_lr":2e-5,
"warmup_num_steps":100,
"warmup_type":"linear"
}
},
"fp16":{
"enabled":true,
"loss_scale":0,
"loss_scale_window":1000,
"initial_scale_power":16,
"hysteresis":2,
"min_loss_scale":1
},
"autotuning":{
"enabled":false,
"results_dir":"./config/autotune/results",
"exps_dir":"./config/autotune/exps",
"overwrite":false,
"metric":"throughput",
"start_profile_step":10,
"end_profile_step":20,
"fast":false,
"max_train_batch_size":32,
"mp_size":1,
"num_tuning_micro_batch_sizes":3,
"tuner_type":"model_based",
"tuner_early_stopping":5,
"tuner_num_trials":50,
"arg_mappings":{
"train_micro_batch_size_per_gpu":"--per_device_train_batch_size",
"gradient_accumulation_steps ":"--gradient_accumulation_steps"
}
},
"zero_optimization":{
"stage":0,
"reduce_bucket_size":"auto",
"contiguous_gradients":true,
"sub_group_size":1e8,
"stage3_prefetch_bucket_size":"auto",
"stage3_param_persistence_threshold":"auto",
"stage3_max_live_parameters":"auto",
"stage3_max_reuse_distance":"auto"
}
}

View File

@ -22,6 +22,7 @@ import psutil
import yaml
import hashlib
import string
import random
from tqdm import tqdm
import torch
@ -34,7 +35,7 @@ import pandas as pd
from datetime import datetime
from datetime import timedelta
from tortoise.api import TextToSpeech, MODELS, get_model_path, pad_or_truncate
from tortoise.api import TextToSpeech as TorToise_TTS, MODELS, get_model_path, pad_or_truncate
from tortoise.utils.audio import load_audio, load_voice, load_voices, get_voice_dir, get_voices
from tortoise.utils.text import split_and_recombine_text
from tortoise.utils.device import get_device_name, set_device_name, get_device_count, get_device_vram, get_device_batch_size, do_gc
@ -68,6 +69,10 @@ try:
from vall_e.emb.qnt import encode as valle_quantize
from vall_e.emb.g2p import encode as valle_phonemize
from vall_e.inference import TTS as VALLE_TTS
import soundfile
VALLE_ENABLED = True
except Exception as e:
pass
@ -111,6 +116,12 @@ def resample( waveform, input_rate, output_rate=44100 ):
return RESAMPLERS[key]( waveform ), output_rate
def generate(**kwargs):
if args.tts_backend == "tortoise":
return generate_tortoise(**kwargs)
if args.tts_backend == "vall-e":
return generate_valle(**kwargs)
def generate_valle(**kwargs):
parameters = {}
parameters.update(kwargs)
@ -140,7 +151,298 @@ def generate(**kwargs):
do_gc()
voice_samples = None
conditioning_latents =None
conditioning_latents = None
sample_voice = None
voice_cache = {}
def fetch_voice( voice ):
voice_dir = f'./voices/{voice}/'
files = [ f'{voice_dir}/{d}' for d in os.listdir(voice_dir) if d[-4:] == ".wav" ]
return files
# return random.choice(files)
def get_settings( override=None ):
settings = {
'ar_temp': float(parameters['temperature']),
'nar_temp': float(parameters['temperature']),
'max_ar_samples': parameters['num_autoregressive_samples'],
}
# could be better to just do a ternary on everything above, but i am not a professional
selected_voice = voice
if override is not None:
if 'voice' in override:
selected_voice = override['voice']
for k in override:
if k not in settings:
continue
settings[k] = override[k]
settings['reference'] = fetch_voice(voice=selected_voice)
return settings
if not parameters['delimiter']:
parameters['delimiter'] = "\n"
elif parameters['delimiter'] == "\\n":
parameters['delimiter'] = "\n"
if parameters['delimiter'] and parameters['delimiter'] != "" and parameters['delimiter'] in parameters['text']:
texts = parameters['text'].split(parameters['delimiter'])
else:
texts = split_and_recombine_text(parameters['text'])
full_start_time = time.time()
outdir = f"{args.results_folder}/{voice}/"
os.makedirs(outdir, exist_ok=True)
audio_cache = {}
volume_adjust = torchaudio.transforms.Vol(gain=args.output_volume, gain_type="amplitude") if args.output_volume != 1 else None
idx = 0
idx_cache = {}
for i, file in enumerate(os.listdir(outdir)):
filename = os.path.basename(file)
extension = os.path.splitext(filename)[1]
if extension != ".json" and extension != ".wav":
continue
match = re.findall(rf"^{voice}_(\d+)(?:.+?)?{extension}$", filename)
if match and len(match) > 0:
key = int(match[0])
idx_cache[key] = True
if len(idx_cache) > 0:
keys = sorted(list(idx_cache.keys()))
idx = keys[-1] + 1
idx = pad(idx, 4)
def get_name(line=0, candidate=0, combined=False):
name = f"{idx}"
if combined:
name = f"{name}_combined"
elif len(texts) > 1:
name = f"{name}_{line}"
if parameters['candidates'] > 1:
name = f"{name}_{candidate}"
return name
def get_info( voice, settings = None, latents = True ):
info = {}
info.update(parameters)
info['time'] = time.time()-full_start_time
info['datetime'] = datetime.now().isoformat()
info['progress'] = None
del info['progress']
if info['delimiter'] == "\n":
info['delimiter'] = "\\n"
if settings is not None:
for k in settings:
if k in info:
info[k] = settings[k]
return info
INFERENCING = True
for line, cut_text in enumerate(texts):
progress.msg_prefix = f'[{str(line+1)}/{str(len(texts))}]'
print(f"{progress.msg_prefix} Generating line: {cut_text}")
start_time = time.time()
# do setting editing
match = re.findall(r'^(\{.+\}) (.+?)$', cut_text)
override = None
if match and len(match) > 0:
match = match[0]
try:
override = json.loads(match[0])
cut_text = match[1].strip()
except Exception as e:
raise Exception("Prompt settings editing requested, but received invalid JSON")
settings = get_settings( override=override )
reference = settings['reference']
settings.pop("reference")
gen = tts.inference(cut_text, reference, **settings )
run_time = time.time()-start_time
print(f"Generating line took {run_time} seconds")
if not isinstance(gen, list):
gen = [gen]
for j, g in enumerate(gen):
wav, sr = g
name = get_name(line=line, candidate=j)
settings['text'] = cut_text
settings['time'] = run_time
settings['datetime'] = datetime.now().isoformat()
# save here in case some error happens mid-batch
#torchaudio.save(f'{outdir}/{voice}_{name}.wav', wav.cpu(), sr)
soundfile.write(f'{outdir}/{voice}_{name}.wav', wav.cpu()[0,0], sr)
wav, sr = torchaudio.load(f'{outdir}/{voice}_{name}.wav')
audio_cache[name] = {
'audio': wav,
'settings': get_info(voice=override['voice'] if override and 'voice' in override else voice, settings=settings)
}
del gen
do_gc()
INFERENCING = False
for k in audio_cache:
audio = audio_cache[k]['audio']
audio, _ = resample(audio, tts.output_sample_rate, args.output_sample_rate)
if volume_adjust is not None:
audio = volume_adjust(audio)
audio_cache[k]['audio'] = audio
torchaudio.save(f'{outdir}/{voice}_{k}.wav', audio, args.output_sample_rate)
output_voices = []
for candidate in range(parameters['candidates']):
if len(texts) > 1:
audio_clips = []
for line in range(len(texts)):
name = get_name(line=line, candidate=candidate)
audio = audio_cache[name]['audio']
audio_clips.append(audio)
name = get_name(candidate=candidate, combined=True)
audio = torch.cat(audio_clips, dim=-1)
torchaudio.save(f'{outdir}/{voice}_{name}.wav', audio, args.output_sample_rate)
audio = audio.squeeze(0).cpu()
audio_cache[name] = {
'audio': audio,
'settings': get_info(voice=voice),
'output': True
}
else:
name = get_name(candidate=candidate)
audio_cache[name]['output'] = True
if args.voice_fixer:
if not voicefixer:
progress(0, "Loading voicefix...")
load_voicefixer()
try:
fixed_cache = {}
for name in progress.tqdm(audio_cache, desc="Running voicefix..."):
del audio_cache[name]['audio']
if 'output' not in audio_cache[name] or not audio_cache[name]['output']:
continue
path = f'{outdir}/{voice}_{name}.wav'
fixed = f'{outdir}/{voice}_{name}_fixed.wav'
voicefixer.restore(
input=path,
output=fixed,
cuda=get_device_name() == "cuda" and args.voice_fixer_use_cuda,
#mode=mode,
)
fixed_cache[f'{name}_fixed'] = {
'settings': audio_cache[name]['settings'],
'output': True
}
audio_cache[name]['output'] = False
for name in fixed_cache:
audio_cache[name] = fixed_cache[name]
except Exception as e:
print(e)
print("\nFailed to run Voicefixer")
for name in audio_cache:
if 'output' not in audio_cache[name] or not audio_cache[name]['output']:
if args.prune_nonfinal_outputs:
audio_cache[name]['pruned'] = True
os.remove(f'{outdir}/{voice}_{name}.wav')
continue
output_voices.append(f'{outdir}/{voice}_{name}.wav')
if not args.embed_output_metadata:
with open(f'{outdir}/{voice}_{name}.json', 'w', encoding="utf-8") as f:
f.write(json.dumps(audio_cache[name]['settings'], indent='\t') )
if args.embed_output_metadata:
for name in progress.tqdm(audio_cache, desc="Embedding metadata..."):
if 'pruned' in audio_cache[name] and audio_cache[name]['pruned']:
continue
metadata = music_tag.load_file(f"{outdir}/{voice}_{name}.wav")
metadata['lyrics'] = json.dumps(audio_cache[name]['settings'])
metadata.save()
if sample_voice is not None:
sample_voice = (tts.input_sample_rate, sample_voice.numpy())
info = get_info(voice=voice, latents=False)
print(f"Generation took {info['time']} seconds, saved to '{output_voices[0]}'\n")
info['seed'] = usedSeed
if 'latents' in info:
del info['latents']
os.makedirs('./config/', exist_ok=True)
with open(f'./config/generate.json', 'w', encoding="utf-8") as f:
f.write(json.dumps(info, indent='\t') )
stats = [
[ parameters['seed'], "{:.3f}".format(info['time']) ]
]
return (
sample_voice,
output_voices,
stats,
)
def generate_tortoise(**kwargs):
parameters = {}
parameters.update(kwargs)
voice = parameters['voice']
progress = parameters['progress'] if 'progress' in parameters else None
if parameters['seed'] == 0:
parameters['seed'] = None
usedSeed = parameters['seed']
global args
global tts
unload_whisper()
unload_voicefixer()
if not tts:
# should check if it's loading or unloaded, and load it if it's unloaded
if tts_loading:
raise Exception("TTS is still initializing...")
load_tts()
if hasattr(tts, "loading") and tts.loading:
raise Exception("TTS is still initializing...")
do_gc()
voice_samples = None
conditioning_latents = None
sample_voice = None
voice_cache = {}
@ -295,11 +597,13 @@ def generate(**kwargs):
def get_info( voice, settings = None, latents = True ):
info = {}
info.update(parameters)
info['time'] = time.time()-full_start_time
info['time'] = time.time()-full_start_time
info['datetime'] = datetime.now().isoformat()
info['model'] = tts.autoregressive_model_path
info['model_hash'] = tts.autoregressive_model_hash
info['progress'] = None
del info['progress']
@ -381,9 +685,10 @@ def generate(**kwargs):
settings['text'] = cut_text
settings['time'] = run_time
settings['datetime'] = datetime.now().isoformat(),
settings['model'] = tts.autoregressive_model_path
settings['model_hash'] = tts.autoregressive_model_hash
settings['datetime'] = datetime.now().isoformat()
if args.tts_backend == "tortoise":
settings['model'] = tts.autoregressive_model_path
settings['model_hash'] = tts.autoregressive_model_hash
audio_cache[name] = {
'audio': audio,
@ -745,8 +1050,8 @@ class TrainingState():
self.it_rate = f'{"{:.3f}".format(1/it_rate)}it/s' if 0 < it_rate and it_rate < 1 else f'{"{:.3f}".format(it_rate)}s/it'
self.it_rates += it_rate
epoch_rate = self.it_rates / self.it * self.steps
if epoch_rate > 0:
if self.it_rates > 0 and self.it * self.steps > 0:
epoch_rate = self.it_rates / self.it * self.steps
self.epoch_rate = f'{"{:.3f}".format(1/epoch_rate)}epoch/s' if 0 < epoch_rate and epoch_rate < 1 else f'{"{:.3f}".format(epoch_rate)}s/epoch'
try:
@ -925,6 +1230,7 @@ class TrainingState():
self.it_rates = 0
unq = {}
averager = None
for log in logs:
with open(log, 'r', encoding="utf-8") as f:
@ -941,16 +1247,18 @@ class TrainingState():
if line.find('Training Metrics:') >= 0:
split = line.split("Training Metrics:")[-1]
data = json.loads(split)
data['mode'] = "training"
name = "train"
mode = "training"
elif line.find('Validation Metrics:') >= 0:
data = json.loads(line.split("Validation Metrics:")[-1])
data['mode'] = "validation"
if "it" not in data:
data['it'] = it
if "epoch" not in data:
data['epoch'] = epoch
name = data['name'] if 'name' in data else "val"
mode = "validation"
else:
continue
@ -960,14 +1268,39 @@ class TrainingState():
it = data['it']
epoch = data['epoch']
# this method should have it at least
unq[f'{it}_{name}'] = data
if args.tts_backend == "vall-e":
if not averager or averager['key'] != f'{it}_{name}' or averager['mode'] != mode:
averager = {
'key': f'{it}_{name}',
'mode': mode,
"metrics": {}
}
for k in data:
if data[k] is None:
continue
averager['metrics'][k] = [ data[k] ]
else:
for k in data:
if data[k] is None:
continue
averager['metrics'][k].append( data[k] )
unq[f'{it}_{mode}_{name}'] = averager
else:
unq[f'{it}_{mode}_{name}'] = data
if update and it <= self.last_info_check_at:
continue
for it in unq:
self.parse_metrics(unq[it])
if args.tts_backend == "vall-e":
stats = unq[it]
data = {k: sum(v) / len(v) for k, v in stats['metrics'].items()}
data['mode'] = stats
data['steps'] = len(stats['metrics']['it'])
else:
data = unq[it]
self.parse_metrics(data)
self.last_info_check_at = highest_step
@ -1087,7 +1420,8 @@ def run_training(config_path, verbose=False, keep_x_past_checkpoints=0, progress
# ensure we have the dvae.pth
get_model_path('dvae.pth')
if args.tts_backend == "tortoise":
get_model_path('dvae.pth')
# I don't know if this is still necessary, as it was bitching at me for not doing this, despite it being in a separate process
torch.multiprocessing.freeze_support()
@ -2086,6 +2420,8 @@ def get_voice_list(dir=get_voice_dir(), append_defaults=False):
res = res + defaults
return res
def get_valle_models(dir="./training/"):
return [ f'{dir}/{d}/config.yaml' for d in os.listdir(dir) if os.path.exists(f'{dir}/{d}/config.yaml') ]
def get_autoregressive_models(dir="./models/finetunes/", prefixed=False, auto=False):
os.makedirs(dir, exist_ok=True)
@ -2268,6 +2604,8 @@ def setup_args():
'tokenizer-json': None,
'phonemizer-backend': 'espeak',
'valle-model': None,
'whisper-backend': 'openai/whisper',
'whisper-model': "base",
@ -2319,6 +2657,8 @@ def setup_args():
parser.add_argument("--phonemizer-backend", default=default_arguments['phonemizer-backend'], help="Specifies which phonemizer backend to use.")
parser.add_argument("--valle-model", default=default_arguments['valle-model'], help="Specifies which VALL-E model to use for sampling.")
parser.add_argument("--whisper-backend", default=default_arguments['whisper-backend'], action='store_true', help="Picks which whisper backend to use (openai/whisper, lightmare/whispercpp)")
parser.add_argument("--whisper-model", default=default_arguments['whisper-model'], help="Specifies which whisper model to use for transcription.")
parser.add_argument("--whisper-batchsize", type=int, default=default_arguments['whisper-batchsize'], help="Specifies batch size for WhisperX")
@ -2389,6 +2729,8 @@ def get_default_settings( hypenated=True ):
'tokenizer-json': args.tokenizer_json,
'phonemizer-backend': args.phonemizer_backend,
'valle-model': args.valle_model,
'whisper-backend': args.whisper_backend,
'whisper-model': args.whisper_model,
@ -2439,6 +2781,8 @@ def update_args( **kwargs ):
args.tokenizer_json = settings['tokenizer_json']
args.phonemizer_backend = settings['phonemizer_backend']
args.valle_model = settings['valle_model']
args.whisper_backend = settings['whisper_backend']
args.whisper_model = settings['whisper_model']
@ -2553,50 +2897,61 @@ def version_check_tts( min_version ):
return True
return False
def load_tts( restart=False, autoregressive_model=None, diffusion_model=None, vocoder_model=None, tokenizer_json=None ):
def load_tts( restart=False,
# TorToiSe configs
autoregressive_model=None, diffusion_model=None, vocoder_model=None, tokenizer_json=None,
# VALL-E configs
valle_model=None,
):
global args
global tts
if restart:
unload_tts()
if autoregressive_model:
args.autoregressive_model = autoregressive_model
else:
autoregressive_model = args.autoregressive_model
if autoregressive_model == "auto":
autoregressive_model = deduce_autoregressive_model()
if diffusion_model:
args.diffusion_model = diffusion_model
else:
diffusion_model = args.diffusion_model
if vocoder_model:
args.vocoder_model = vocoder_model
else:
vocoder_model = args.vocoder_model
if tokenizer_json:
args.tokenizer_json = tokenizer_json
else:
tokenizer_json = args.tokenizer_json
if get_device_name() == "cpu":
print("!!!! WARNING !!!! No GPU available in PyTorch. You may need to reinstall PyTorch.")
tts_loading = True
print(f"Loading TorToiSe... (AR: {autoregressive_model}, vocoder: {vocoder_model})")
tts = TextToSpeech(minor_optimizations=not args.low_vram, autoregressive_model_path=autoregressive_model, diffusion_model_path=diffusion_model, vocoder_model=vocoder_model, tokenizer_json=tokenizer_json, unsqueeze_sample_batches=args.unsqueeze_sample_batches)
if args.tts_backend == "tortoise":
if autoregressive_model:
args.autoregressive_model = autoregressive_model
else:
autoregressive_model = args.autoregressive_model
if autoregressive_model == "auto":
autoregressive_model = deduce_autoregressive_model()
if diffusion_model:
args.diffusion_model = diffusion_model
else:
diffusion_model = args.diffusion_model
if vocoder_model:
args.vocoder_model = vocoder_model
else:
vocoder_model = args.vocoder_model
if tokenizer_json:
args.tokenizer_json = tokenizer_json
else:
tokenizer_json = args.tokenizer_json
if get_device_name() == "cpu":
print("!!!! WARNING !!!! No GPU available in PyTorch. You may need to reinstall PyTorch.")
print(f"Loading TorToiSe... (AR: {autoregressive_model}, diffusion: {diffusion_model}, vocoder: {vocoder_model})")
tts = TorToise_TTS(minor_optimizations=not args.low_vram, autoregressive_model_path=autoregressive_model, diffusion_model_path=diffusion_model, vocoder_model=vocoder_model, tokenizer_json=tokenizer_json, unsqueeze_sample_batches=args.unsqueeze_sample_batches)
elif args.tts_backend == "vall-e":
if valle_model:
args.valle_model = valle_model
else:
valle_model = args.valle_model
print(f"Loading VALL-E... (Config: {valle_model})")
tts = VALLE_TTS(config=args.valle_model)
print("Loaded TTS, ready for generation.")
tts_loading = False
get_model_path('dvae.pth')
print("Loaded TorToiSe, ready for generation.")
return tts
setup_tortoise = load_tts
def unload_tts():
global tts
@ -2643,6 +2998,9 @@ def deduce_autoregressive_model(voice=None):
return get_model_path('autoregressive.pth')
def update_autoregressive_model(autoregressive_model_path):
if args.tts_backend != "tortoise":
raise f"Unsupported backend: {args.tts_backend}"
match = re.findall(r'^\[[a-fA-F0-9]{8}\] (.+?)$', autoregressive_model_path)
if match:
autoregressive_model_path = match[0]
@ -2677,6 +3035,9 @@ def update_autoregressive_model(autoregressive_model_path):
return autoregressive_model_path
def update_diffusion_model(diffusion_model_path):
if args.tts_backend != "tortoise":
raise f"Unsupported backend: {args.tts_backend}"
match = re.findall(r'^\[[a-fA-F0-9]{8}\] (.+?)$', diffusion_model_path)
if match:
diffusion_model_path = match[0]
@ -2711,6 +3072,9 @@ def update_diffusion_model(diffusion_model_path):
return diffusion_model_path
def update_vocoder_model(vocoder_model):
if args.tts_backend != "tortoise":
raise f"Unsupported backend: {args.tts_backend}"
args.vocoder_model = vocoder_model
save_args_settings()
print(f'Stored vocoder model to settings: {vocoder_model}')
@ -2733,6 +3097,9 @@ def update_vocoder_model(vocoder_model):
return vocoder_model
def update_tokenizer(tokenizer_json):
if args.tts_backend != "tortoise":
raise f"Unsupported backend: {args.tts_backend}"
args.tokenizer_json = tokenizer_json
save_args_settings()
print(f'Stored tokenizer to settings: {tokenizer_json}')

View File

@ -315,6 +315,8 @@ def setup_gradio():
voice_list = get_voice_list()
result_voices = get_voice_list(args.results_folder)
valle_models = get_valle_models()
autoregressive_models = get_autoregressive_models()
diffusion_models = get_diffusion_models()
tokenizer_jsons = get_tokenizer_jsons()
@ -337,11 +339,11 @@ def setup_gradio():
with gr.Column():
GENERATE_SETTINGS["delimiter"] = gr.Textbox(lines=1, label="Line Delimiter", placeholder="\\n")
GENERATE_SETTINGS["emotion"] = gr.Radio( ["Happy", "Sad", "Angry", "Disgusted", "Arrogant", "Custom", "None"], value="None", label="Emotion", type="value", interactive=True )
GENERATE_SETTINGS["emotion"] = gr.Radio( ["Happy", "Sad", "Angry", "Disgusted", "Arrogant", "Custom", "None"], value="None", label="Emotion", type="value", interactive=True, visible=args.tts_backend=="tortoise" )
GENERATE_SETTINGS["prompt"] = gr.Textbox(lines=1, label="Custom Emotion", visible=False)
GENERATE_SETTINGS["voice"] = gr.Dropdown(choices=voice_list_with_defaults, label="Voice", type="value", value=voice_list_with_defaults[0]) # it'd be very cash money if gradio was able to default to the first value in the list without this shit
GENERATE_SETTINGS["mic_audio"] = gr.Audio( label="Microphone Source", source="microphone", type="filepath", visible=False )
GENERATE_SETTINGS["voice_latents_chunks"] = gr.Number(label="Voice Chunks", precision=0, value=0)
GENERATE_SETTINGS["voice_latents_chunks"] = gr.Number(label="Voice Chunks", precision=0, value=0, visible=args.tts_backend=="tortoise")
with gr.Row():
refresh_voices = gr.Button(value="Refresh Voice List")
recompute_voice_latents = gr.Button(value="(Re)Compute Voice Latents")
@ -357,17 +359,17 @@ def setup_gradio():
outputs=GENERATE_SETTINGS["mic_audio"],
)
with gr.Column():
GENERATE_SETTINGS["candidates"] = gr.Slider(value=1, minimum=1, maximum=6, step=1, label="Candidates")
GENERATE_SETTINGS["candidates"] = gr.Slider(value=1, minimum=1, maximum=6, step=1, label="Candidates", visible=args.tts_backend=="tortoise")
GENERATE_SETTINGS["seed"] = gr.Number(value=0, precision=0, label="Seed")
preset = gr.Radio( ["Ultra Fast", "Fast", "Standard", "High Quality"], label="Preset", type="value", value="Ultra Fast" )
GENERATE_SETTINGS["num_autoregressive_samples"] = gr.Slider(value=16, minimum=2, maximum=512, step=1, label="Samples")
GENERATE_SETTINGS["diffusion_iterations"] = gr.Slider(value=30, minimum=0, maximum=512, step=1, label="Iterations")
GENERATE_SETTINGS["diffusion_iterations"] = gr.Slider(value=30, minimum=0, maximum=512, step=1, label="Iterations", visible=args.tts_backend=="tortoise")
GENERATE_SETTINGS["temperature"] = gr.Slider(value=0.2, minimum=0, maximum=1, step=0.1, label="Temperature")
show_experimental_settings = gr.Checkbox(label="Show Experimental Settings")
show_experimental_settings = gr.Checkbox(label="Show Experimental Settings", visible=args.tts_backend=="tortoise")
reset_generate_settings_button = gr.Button(value="Reset to Default")
with gr.Column(visible=False) as col:
experimental_column = col
@ -606,10 +608,12 @@ def setup_gradio():
EXEC_SETTINGS['device_override'] = gr.Textbox(label="Device Override", value=args.device_override)
EXEC_SETTINGS['results_folder'] = gr.Textbox(label="Results Folder", value=args.results_folder)
with gr.Column():
# EXEC_SETTINGS['tts_backend'] = gr.Dropdown(TTSES, label="TTS Backend", value=args.tts_backend if args.tts_backend else TTSES[0])
with gr.Column(visible=args.tts_backend=="vall-e"):
EXEC_SETTINGS['valle_model'] = gr.Dropdown(choices=valle_models, label="VALL-E Model Config", value=args.valle_model if args.valle_model else valle_models[0])
with gr.Column(visible=args.tts_backend=="tortoise"):
EXEC_SETTINGS['autoregressive_model'] = gr.Dropdown(choices=["auto"] + autoregressive_models, label="Autoregressive Model", value=args.autoregressive_model if args.autoregressive_model else "auto")
EXEC_SETTINGS['diffusion_model'] = gr.Dropdown(choices=diffusion_models, label="Diffusion Model", value=args.diffusion_model if args.diffusion_model else diffusion_models[0])
EXEC_SETTINGS['vocoder_model'] = gr.Dropdown(VOCODERS, label="Vocoder", value=args.vocoder_model if args.vocoder_model else VOCODERS[-1])