From 4744120be2731b613a853a30d49adce98411ae61 Mon Sep 17 00:00:00 2001
From: mrq
Date: Fri, 31 Mar 2023 03:26:00 +0000
Subject: [PATCH] added VALL-E inference support (very rudimentary, gimped, but it will load a model trained on a config generated through the web UI)
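
Notes: generation now dispatches on `args.tts_backend` ("tortoise" or
"vall-e"); the UI dropdown for picking the backend is still commented out
in src/webui.py, so the sketch below is illustrative, and only the
`--valle-model` flag is actually added by this patch:

    # point the web UI at a config produced by a training run:
    #   python ./src/main.py --valle-model ./training/my-voice/config.yaml
    # with args.tts_backend == "vall-e", load_tts() builds the backend
    # from that YAML:
    #   tts = VALLE_TTS(config=args.valle_model)
    # and generate() routes to generate_valle() instead of generate_tortoise().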
---
 config/ds_config.json |  61 ------
 src/utils.py          | 463 +++++++++++++++++++++++++++++++++++++-----
 src/webui.py          |  18 +-
 3 files changed, 426 insertions(+), 116 deletions(-)
 delete mode 100755 config/ds_config.json

diff --git a/config/ds_config.json b/config/ds_config.json
deleted file mode 100755
index 47827ca..0000000
--- a/config/ds_config.json
+++ /dev/null
@@ -1,61 +0,0 @@
-{
-	"optimizer": {
-		"type": "AdamW",
-		"params": {
-			"lr": 2e-05,
-			"betas": [
-				0.9,
-				0.96
-			],
-			"eps": 1e-07,
-			"weight_decay": 0.01
-		}
-	},
-	"scheduler":{
-		"type":"WarmupLR",
-		"params":{
-			"warmup_min_lr":0,
-			"warmup_max_lr":2e-5,
-			"warmup_num_steps":100,
-			"warmup_type":"linear"
-		}
-	},
-	"fp16":{
-		"enabled":true,
-		"loss_scale":0,
-		"loss_scale_window":1000,
-		"initial_scale_power":16,
-		"hysteresis":2,
-		"min_loss_scale":1
-	},
-	"autotuning":{
-		"enabled":false,
-		"results_dir":"./config/autotune/results",
-		"exps_dir":"./config/autotune/exps",
-		"overwrite":false,
-		"metric":"throughput",
-		"start_profile_step":10,
-		"end_profile_step":20,
-		"fast":false,
-		"max_train_batch_size":32,
-		"mp_size":1,
-		"num_tuning_micro_batch_sizes":3,
-		"tuner_type":"model_based",
-		"tuner_early_stopping":5,
-		"tuner_num_trials":50,
-		"arg_mappings":{
-			"train_micro_batch_size_per_gpu":"--per_device_train_batch_size",
-			"gradient_accumulation_steps ":"--gradient_accumulation_steps"
-		}
-	},
-	"zero_optimization":{
-		"stage":0,
-		"reduce_bucket_size":"auto",
-		"contiguous_gradients":true,
-		"sub_group_size":1e8,
-		"stage3_prefetch_bucket_size":"auto",
-		"stage3_param_persistence_threshold":"auto",
-		"stage3_max_live_parameters":"auto",
-		"stage3_max_reuse_distance":"auto"
-	}
-}
\ No newline at end of file
diff --git a/src/utils.py b/src/utils.py
index 22d848d..ae6fb89 100755
--- a/src/utils.py
+++ b/src/utils.py
@@ -22,6 +22,7 @@ import psutil
 import yaml
 import hashlib
 import string
+import random
 
 from tqdm import tqdm
 import torch
@@ -34,7 +35,7 @@ import pandas as pd
 from datetime import datetime
 from datetime import timedelta
 
-from tortoise.api import TextToSpeech, MODELS, get_model_path, pad_or_truncate
+from tortoise.api import TextToSpeech as TorToise_TTS, MODELS, get_model_path, pad_or_truncate
 from tortoise.utils.audio import load_audio, load_voice, load_voices, get_voice_dir, get_voices
 from tortoise.utils.text import split_and_recombine_text
 from tortoise.utils.device import get_device_name, set_device_name, get_device_count, get_device_vram, get_device_batch_size, do_gc
@@ -68,6 +69,10 @@ try:
 	from vall_e.emb.qnt import encode as valle_quantize
 	from vall_e.emb.g2p import encode as valle_phonemize
 
+	from vall_e.inference import TTS as VALLE_TTS
+
+	import soundfile
+
 	VALLE_ENABLED = True
 except Exception as e:
 	pass
@@ -111,6 +116,12 @@ def resample( waveform, input_rate, output_rate=44100 ):
 	return RESAMPLERS[key]( waveform ), output_rate
 
 def generate(**kwargs):
+	if args.tts_backend == "tortoise":
+		return generate_tortoise(**kwargs)
+	if args.tts_backend == "vall-e":
+		return generate_valle(**kwargs)
+
+def generate_valle(**kwargs):
 	parameters = {}
 	parameters.update(kwargs)
 
@@ -140,7 +151,298 @@ def generate(**kwargs):
 	do_gc()
 
 	voice_samples = None
-	conditioning_latents =None
+	conditioning_latents = None
+	sample_voice = None
+
+	voice_cache = {}
+	def fetch_voice( voice ):
+		voice_dir = f'./voices/{voice}/'
+		files = [ f'{voice_dir}/{d}' for d in os.listdir(voice_dir) if d[-4:] == ".wav" ]
+		return files
+		# return random.choice(files)
+
+	def get_settings( override=None ):
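+		# `override` is an inline JSON dict that may prefix a line of input
+		# text (see the regex in the inference loop below), e.g.:
+		#   {"voice": "other", "ar_temp": 0.5} Read this line as someone else.
+		# keys not present in `settings` are dropped, aside from the
+		# special-cased `voice`, which swaps the reference clips instead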
+		settings = {
+			'ar_temp': float(parameters['temperature']),
+			'nar_temp': float(parameters['temperature']),
+			'max_ar_samples': parameters['num_autoregressive_samples'],
+		}
+
+		# could be better to just do a ternary on everything above, but i am not a professional
+		selected_voice = voice
+		if override is not None:
+			if 'voice' in override:
+				selected_voice = override['voice']
+
+			for k in override:
+				if k not in settings:
+					continue
+				settings[k] = override[k]
+
+		settings['reference'] = fetch_voice(voice=selected_voice)
+		return settings
+
+	if not parameters['delimiter']:
+		parameters['delimiter'] = "\n"
+	elif parameters['delimiter'] == "\\n":
+		parameters['delimiter'] = "\n"
+
+	if parameters['delimiter'] and parameters['delimiter'] != "" and parameters['delimiter'] in parameters['text']:
+		texts = parameters['text'].split(parameters['delimiter'])
+	else:
+		texts = split_and_recombine_text(parameters['text'])
+
+	full_start_time = time.time()
+
+	outdir = f"{args.results_folder}/{voice}/"
+	os.makedirs(outdir, exist_ok=True)
+
+	audio_cache = {}
+
+	volume_adjust = torchaudio.transforms.Vol(gain=args.output_volume, gain_type="amplitude") if args.output_volume != 1 else None
+
+	idx = 0
+	idx_cache = {}
+	for i, file in enumerate(os.listdir(outdir)):
+		filename = os.path.basename(file)
+		extension = os.path.splitext(filename)[1]
+		if extension != ".json" and extension != ".wav":
+			continue
+		match = re.findall(rf"^{voice}_(\d+)(?:.+?)?{extension}$", filename)
+		if match and len(match) > 0:
+			key = int(match[0])
+			idx_cache[key] = True
+
+	if len(idx_cache) > 0:
+		keys = sorted(list(idx_cache.keys()))
+		idx = keys[-1] + 1
+
+	idx = pad(idx, 4)
+
+	def get_name(line=0, candidate=0, combined=False):
+		name = f"{idx}"
+		if combined:
+			name = f"{name}_combined"
+		elif len(texts) > 1:
+			name = f"{name}_{line}"
+		if parameters['candidates'] > 1:
+			name = f"{name}_{candidate}"
+		return name
+
+	def get_info( voice, settings = None, latents = True ):
+		info = {}
+		info.update(parameters)
+
+		info['time'] = time.time()-full_start_time
+		info['datetime'] = datetime.now().isoformat()
+
+		info['progress'] = None
+		del info['progress']
+
+		if info['delimiter'] == "\n":
+			info['delimiter'] = "\\n"
+
+		if settings is not None:
+			for k in settings:
+				if k in info:
+					info[k] = settings[k]
+		return info
+
+	INFERENCING = True
+	for line, cut_text in enumerate(texts):
+		progress.msg_prefix = f'[{str(line+1)}/{str(len(texts))}]'
+		print(f"{progress.msg_prefix} Generating line: {cut_text}")
+		start_time = time.time()
+
+		# do setting editing
+		match = re.findall(r'^(\{.+\}) (.+?)$', cut_text)
+		override = None
+		if match and len(match) > 0:
+			match = match[0]
+			try:
+				override = json.loads(match[0])
+				cut_text = match[1].strip()
+			except Exception as e:
+				raise Exception("Prompt settings editing requested, but received invalid JSON")
+
+		settings = get_settings( override=override )
+		reference = settings['reference']
+		settings.pop("reference")
+
+		gen = tts.inference(cut_text, reference, **settings )
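+		# NOTE: assuming vall_e.inference.TTS returns a single (waveform, sample
+		# rate) pair here; it gets wrapped in a list below so single results and
+		# candidate lists can be iterated uniformly either way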
+
+		run_time = time.time()-start_time
+		print(f"Generating line took {run_time} seconds")
+
+		if not isinstance(gen, list):
+			gen = [gen]
+
+		for j, g in enumerate(gen):
+			wav, sr = g
+			name = get_name(line=line, candidate=j)
+
+			settings['text'] = cut_text
+			settings['time'] = run_time
+			settings['datetime'] = datetime.now().isoformat()
+
+			# save here in case some error happens mid-batch
+			#torchaudio.save(f'{outdir}/{voice}_{name}.wav', wav.cpu(), sr)
+			soundfile.write(f'{outdir}/{voice}_{name}.wav', wav.cpu()[0,0], sr)
+			wav, sr = torchaudio.load(f'{outdir}/{voice}_{name}.wav')
+
+			audio_cache[name] = {
+				'audio': wav,
+				'settings': get_info(voice=override['voice'] if override and 'voice' in override else voice, settings=settings)
+			}
+
+		del gen
+		do_gc()
+	INFERENCING = False
+
+	for k in audio_cache:
+		audio = audio_cache[k]['audio']
+
+		audio, _ = resample(audio, tts.output_sample_rate, args.output_sample_rate)
+		if volume_adjust is not None:
+			audio = volume_adjust(audio)
+
+		audio_cache[k]['audio'] = audio
+		torchaudio.save(f'{outdir}/{voice}_{k}.wav', audio, args.output_sample_rate)
+
+	output_voices = []
+	for candidate in range(parameters['candidates']):
+		if len(texts) > 1:
+			audio_clips = []
+			for line in range(len(texts)):
+				name = get_name(line=line, candidate=candidate)
+				audio = audio_cache[name]['audio']
+				audio_clips.append(audio)
+
+			name = get_name(candidate=candidate, combined=True)
+			audio = torch.cat(audio_clips, dim=-1)
+			torchaudio.save(f'{outdir}/{voice}_{name}.wav', audio, args.output_sample_rate)
+
+			audio = audio.squeeze(0).cpu()
+			audio_cache[name] = {
+				'audio': audio,
+				'settings': get_info(voice=voice),
+				'output': True
+			}
+		else:
+			name = get_name(candidate=candidate)
+			audio_cache[name]['output'] = True
+
+	if args.voice_fixer:
+		if not voicefixer:
+			progress(0, "Loading voicefix...")
+			load_voicefixer()
+
+		try:
+			fixed_cache = {}
+			for name in progress.tqdm(audio_cache, desc="Running voicefix..."):
+				del audio_cache[name]['audio']
+				if 'output' not in audio_cache[name] or not audio_cache[name]['output']:
+					continue
+
+				path = f'{outdir}/{voice}_{name}.wav'
+				fixed = f'{outdir}/{voice}_{name}_fixed.wav'
+				voicefixer.restore(
+					input=path,
+					output=fixed,
+					cuda=get_device_name() == "cuda" and args.voice_fixer_use_cuda,
+					#mode=mode,
+				)
+
+				fixed_cache[f'{name}_fixed'] = {
+					'settings': audio_cache[name]['settings'],
+					'output': True
+				}
+				audio_cache[name]['output'] = False
+
+			for name in fixed_cache:
+				audio_cache[name] = fixed_cache[name]
+		except Exception as e:
+			print(e)
+			print("\nFailed to run Voicefixer")
+
+	for name in audio_cache:
+		if 'output' not in audio_cache[name] or not audio_cache[name]['output']:
+			if args.prune_nonfinal_outputs:
+				audio_cache[name]['pruned'] = True
+				os.remove(f'{outdir}/{voice}_{name}.wav')
+			continue
+
+		output_voices.append(f'{outdir}/{voice}_{name}.wav')
+
+		if not args.embed_output_metadata:
+			with open(f'{outdir}/{voice}_{name}.json', 'w', encoding="utf-8") as f:
+				f.write(json.dumps(audio_cache[name]['settings'], indent='\t') )
+
+	if args.embed_output_metadata:
+		for name in progress.tqdm(audio_cache, desc="Embedding metadata..."):
+			if 'pruned' in audio_cache[name] and audio_cache[name]['pruned']:
+				continue
+
+			metadata = music_tag.load_file(f"{outdir}/{voice}_{name}.wav")
+			metadata['lyrics'] = json.dumps(audio_cache[name]['settings'])
+			metadata.save()
+
+	if sample_voice is not None:
+		sample_voice = (tts.input_sample_rate, sample_voice.numpy())
+
+	info = get_info(voice=voice, latents=False)
+	print(f"Generation took {info['time']} seconds, saved to '{output_voices[0]}'\n")
+
+	info['seed'] = usedSeed
+	if 'latents' in info:
+		del info['latents']
+
+	os.makedirs('./config/', exist_ok=True)
+	with open(f'./config/generate.json', 'w', encoding="utf-8") as f:
+		f.write(json.dumps(info, indent='\t') )
+
+	stats = [
+		[ parameters['seed'], "{:.3f}".format(info['time']) ]
+	]
+
+	return (
+		sample_voice,
+		output_voices,
+		stats,
+	)
+
+
+def generate_tortoise(**kwargs):
+	parameters = {}
+	parameters.update(kwargs)
+
+	voice = parameters['voice']
+	progress = parameters['progress'] if 'progress' in parameters else None
+	if parameters['seed'] == 0:
+		parameters['seed'] = None
+
+	usedSeed = parameters['seed']
+
+	global args
+	global tts
+
+	unload_whisper()
+	unload_voicefixer()
+
+	if not tts:
+		# should check if it's loading or unloaded, and load it if it's unloaded
+		if tts_loading:
+			raise Exception("TTS is still initializing...")
+		load_tts()
+	if hasattr(tts, "loading") and tts.loading:
+		raise Exception("TTS is still initializing...")
+
+	do_gc()
+
+	voice_samples = None
+	conditioning_latents = None
 	sample_voice = None
 
 	voice_cache = {}
@@ -295,11 +597,13 @@ def generate(**kwargs):
 	def get_info( voice, settings = None, latents = True ):
 		info = {}
 		info.update(parameters)
-		info['time'] = time.time()-full_start_time
+		info['time'] = time.time()-full_start_time 
 		info['datetime'] = datetime.now().isoformat()
+
 		info['model'] = tts.autoregressive_model_path
 		info['model_hash'] = tts.autoregressive_model_hash 
+
 		info['progress'] = None
 		del info['progress']
 
@@ -381,9 +685,10 @@ def generate(**kwargs):
 
 			settings['text'] = cut_text
 			settings['time'] = run_time
-			settings['datetime'] = datetime.now().isoformat(),
-			settings['model'] = tts.autoregressive_model_path
-			settings['model_hash'] = tts.autoregressive_model_hash
+			settings['datetime'] = datetime.now().isoformat()
+			if args.tts_backend == "tortoise":
+				settings['model'] = tts.autoregressive_model_path
+				settings['model_hash'] = tts.autoregressive_model_hash
 
 			audio_cache[name] = {
 				'audio': audio,
@@ -745,8 +1050,8 @@ class TrainingState():
 			self.it_rate = f'{"{:.3f}".format(1/it_rate)}it/s' if 0 < it_rate and it_rate < 1 else f'{"{:.3f}".format(it_rate)}s/it'
 			self.it_rates += it_rate
 
-			epoch_rate = self.it_rates / self.it * self.steps
-			if epoch_rate > 0:
+			if self.it_rates > 0 and self.it * self.steps > 0:
+				epoch_rate = self.it_rates / self.it * self.steps
 				self.epoch_rate = f'{"{:.3f}".format(1/epoch_rate)}epoch/s' if 0 < epoch_rate and epoch_rate < 1 else f'{"{:.3f}".format(epoch_rate)}s/epoch'
 
 		try:
@@ -925,6 +1230,7 @@ class TrainingState():
 		self.it_rates = 0
 
 		unq = {}
+		averager = None
 
 		for log in logs:
 			with open(log, 'r', encoding="utf-8") as f:
@@ -941,16 +1247,18 @@ class TrainingState():
 				if line.find('Training Metrics:') >= 0:
 					split = line.split("Training Metrics:")[-1]
 					data = json.loads(split)
-					data['mode'] = "training"
+					name = "train"
+					mode = "training"
 				elif line.find('Validation Metrics:') >= 0:
 					data = json.loads(line.split("Validation Metrics:")[-1])
-					data['mode'] = "validation"
 
 					if "it" not in data:
 						data['it'] = it
 					if "epoch" not in data:
 						data['epoch'] = epoch
+
+					name = data['name'] if 'name' in data else "val"
+					mode = "validation"
 				else:
 					continue
 
@@ -960,14 +1268,39 @@ class TrainingState():
 				it = data['it']
 				epoch = data['epoch']
 
-				# this method should have it at least
-				unq[f'{it}_{name}'] = data
+				if args.tts_backend == "vall-e":
+					if not averager or averager['key'] != f'{it}_{name}' or averager['mode'] != mode:
+						averager = {
+							'key': f'{it}_{name}',
+							'mode': mode,
+							"metrics": {}
+						}
+						for k in data:
+							if data[k] is None:
+								continue
+							averager['metrics'][k] = [ data[k] ]
+					else:
+						for k in data:
+							if data[k] is None:
+								continue
+							averager['metrics'][k].append( data[k] )
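+
+					# NOTE: the trainer can emit several metric lines that share an
+					# iteration (presumably one per device/batch); they're pooled per
+					# key here and collapsed into means when parsed below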
+					unq[f'{it}_{mode}_{name}'] = averager
+				else:
+					unq[f'{it}_{mode}_{name}'] = data
 
 			if update and it <= self.last_info_check_at:
 				continue
 
 		for it in unq:
-			self.parse_metrics(unq[it])
+			if args.tts_backend == "vall-e":
+				stats = unq[it]
+				data = {k: sum(v) / len(v) for k, v in stats['metrics'].items()}
+				data['mode'] = stats['mode']
+				data['steps'] = len(stats['metrics']['it'])
+			else:
+				data = unq[it]
+			self.parse_metrics(data)
 
 		self.last_info_check_at = highest_step
 
@@ -1087,7 +1420,8 @@ def run_training(config_path, verbose=False, keep_x_past_checkpoints=0, progress
 
 	# ensure we have the dvae.pth
-	get_model_path('dvae.pth')
+	if args.tts_backend == "tortoise":
+		get_model_path('dvae.pth')
 
 	# I don't know if this is still necessary, as it was bitching at me for not doing this, despite it being in a separate process
 	torch.multiprocessing.freeze_support()
@@ -2086,6 +2420,8 @@ def get_voice_list(dir=get_voice_dir(), append_defaults=False):
 			res = res + defaults
 	return res
 
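+# NOTE: assumes every training run lives under ./training/<run>/ with the YAML
+# the web UI generated, e.g. ./training/my-voice/config.yaml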
+def get_valle_models(dir="./training/"):
+	return [ f'{dir}/{d}/config.yaml' for d in os.listdir(dir) if os.path.exists(f'{dir}/{d}/config.yaml') ]
 
 def get_autoregressive_models(dir="./models/finetunes/", prefixed=False, auto=False):
 	os.makedirs(dir, exist_ok=True)
@@ -2268,6 +2604,8 @@ def setup_args():
 		'tokenizer-json': None,
 
 		'phonemizer-backend': 'espeak',
+
+		'valle-model': None,
 
 		'whisper-backend': 'openai/whisper',
 		'whisper-model': "base",
@@ -2319,6 +2657,8 @@ def setup_args():
 
 	parser.add_argument("--phonemizer-backend", default=default_arguments['phonemizer-backend'], help="Specifies which phonemizer backend to use.")
 
+	parser.add_argument("--valle-model", default=default_arguments['valle-model'], help="Specifies which VALL-E model to use for sampling.")
+
 	parser.add_argument("--whisper-backend", default=default_arguments['whisper-backend'], action='store_true', help="Picks which whisper backend to use (openai/whisper, lightmare/whispercpp)")
 	parser.add_argument("--whisper-model", default=default_arguments['whisper-model'], help="Specifies which whisper model to use for transcription.")
 	parser.add_argument("--whisper-batchsize", type=int, default=default_arguments['whisper-batchsize'], help="Specifies batch size for WhisperX")
@@ -2389,6 +2729,8 @@ def get_default_settings( hypenated=True ):
 		'tokenizer-json': args.tokenizer_json,
 
 		'phonemizer-backend': args.phonemizer_backend,
+
+		'valle-model': args.valle_model,
 
 		'whisper-backend': args.whisper_backend,
 		'whisper-model': args.whisper_model,
@@ -2439,6 +2781,8 @@ def update_args( **kwargs ):
 	args.tokenizer_json = settings['tokenizer_json']
 
 	args.phonemizer_backend = settings['phonemizer_backend']
+
+	args.valle_model = settings['valle_model']
 
 	args.whisper_backend = settings['whisper_backend']
 	args.whisper_model = settings['whisper_model']
@@ -2553,50 +2897,61 @@ def version_check_tts( min_version ):
 		return True
 	return False
 
-def load_tts( restart=False, autoregressive_model=None, diffusion_model=None, vocoder_model=None, tokenizer_json=None ):
+def load_tts( restart=False,
+	# TorToiSe configs
+	autoregressive_model=None, diffusion_model=None, vocoder_model=None, tokenizer_json=None,
+	# VALL-E configs
+	valle_model=None,
+):
 	global args
 	global tts
 
 	if restart:
 		unload_tts()
 
-	if autoregressive_model:
-		args.autoregressive_model = autoregressive_model
-	else:
-		autoregressive_model = args.autoregressive_model
-
-	if autoregressive_model == "auto":
-		autoregressive_model = deduce_autoregressive_model()
-
-	if diffusion_model:
-		args.diffusion_model = diffusion_model
-	else:
-		diffusion_model = args.diffusion_model
-
-	if vocoder_model:
-		args.vocoder_model = vocoder_model
-	else:
-		vocoder_model = args.vocoder_model
-
-	if tokenizer_json:
-		args.tokenizer_json = tokenizer_json
-	else:
-		tokenizer_json = args.tokenizer_json
-
-	if get_device_name() == "cpu":
-		print("!!!! WARNING !!!! No GPU available in PyTorch. You may need to reinstall PyTorch.")
-
 	tts_loading = True
-	print(f"Loading TorToiSe... (AR: {autoregressive_model}, vocoder: {vocoder_model})")
-	tts = TextToSpeech(minor_optimizations=not args.low_vram, autoregressive_model_path=autoregressive_model, diffusion_model_path=diffusion_model, vocoder_model=vocoder_model, tokenizer_json=tokenizer_json, unsqueeze_sample_batches=args.unsqueeze_sample_batches)
+	if args.tts_backend == "tortoise":
+		if autoregressive_model:
+			args.autoregressive_model = autoregressive_model
+		else:
+			autoregressive_model = args.autoregressive_model
+
+		if autoregressive_model == "auto":
+			autoregressive_model = deduce_autoregressive_model()
+
+		if diffusion_model:
+			args.diffusion_model = diffusion_model
+		else:
+			diffusion_model = args.diffusion_model
+
+		if vocoder_model:
+			args.vocoder_model = vocoder_model
+		else:
+			vocoder_model = args.vocoder_model
+
+		if tokenizer_json:
+			args.tokenizer_json = tokenizer_json
+		else:
+			tokenizer_json = args.tokenizer_json
+
+		if get_device_name() == "cpu":
+			print("!!!! WARNING !!!! No GPU available in PyTorch. You may need to reinstall PyTorch.")
+
+		print(f"Loading TorToiSe... (AR: {autoregressive_model}, diffusion: {diffusion_model}, vocoder: {vocoder_model})")
+		tts = TorToise_TTS(minor_optimizations=not args.low_vram, autoregressive_model_path=autoregressive_model, diffusion_model_path=diffusion_model, vocoder_model=vocoder_model, tokenizer_json=tokenizer_json, unsqueeze_sample_batches=args.unsqueeze_sample_batches)
+	elif args.tts_backend == "vall-e":
+		if valle_model:
+			args.valle_model = valle_model
+		else:
+			valle_model = args.valle_model
+
+		print(f"Loading VALL-E... (Config: {valle_model})")
+		tts = VALLE_TTS(config=valle_model)
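+		# NOTE: checkpoint path, device, and quantizer settings are assumed to
+		# be resolved from the trainer-generated YAML itself; nothing else is
+		# passed to the constructor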
+
+	print("Loaded TTS, ready for generation.")
 	tts_loading = False
-
-	get_model_path('dvae.pth')
-	print("Loaded TorToiSe, ready for generation.")
 	return tts
 
-setup_tortoise = load_tts
-
 def unload_tts():
 	global tts
 
@@ -2643,6 +2998,9 @@ def deduce_autoregressive_model(voice=None):
 	return get_model_path('autoregressive.pth')
 
 def update_autoregressive_model(autoregressive_model_path):
+	if args.tts_backend != "tortoise":
+		raise Exception(f"Unsupported backend: {args.tts_backend}")
+
 	match = re.findall(r'^\[[a-fA-F0-9]{8}\] (.+?)$', autoregressive_model_path)
 	if match:
 		autoregressive_model_path = match[0]
@@ -2677,6 +3035,9 @@ def update_autoregressive_model(autoregressive_model_path):
 	return autoregressive_model_path
 
 def update_diffusion_model(diffusion_model_path):
+	if args.tts_backend != "tortoise":
+		raise Exception(f"Unsupported backend: {args.tts_backend}")
+
 	match = re.findall(r'^\[[a-fA-F0-9]{8}\] (.+?)$', diffusion_model_path)
 	if match:
 		diffusion_model_path = match[0]
@@ -2711,6 +3072,9 @@ def update_vocoder_model(vocoder_model):
+	if args.tts_backend != "tortoise":
+		raise Exception(f"Unsupported backend: {args.tts_backend}")
+
 	args.vocoder_model = vocoder_model
 	save_args_settings()
 	print(f'Stored vocoder model to settings: {vocoder_model}')
@@ -2733,6 +3097,9 @@ def update_tokenizer(tokenizer_json):
+	if args.tts_backend != "tortoise":
+		raise Exception(f"Unsupported backend: {args.tts_backend}")
+
 	args.tokenizer_json = tokenizer_json
 	save_args_settings()
 	print(f'Stored tokenizer to settings: {tokenizer_json}')
diff --git a/src/webui.py b/src/webui.py
index 7824d57..7a6a03f 100755
--- a/src/webui.py
+++ b/src/webui.py
@@ -315,6 +317,8 @@ def setup_gradio():
 	voice_list = get_voice_list()
 	result_voices = get_voice_list(args.results_folder)
 
+	valle_models = get_valle_models()
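+	# one entry per ./training/<run>/config.yaml; the Settings tab below
+	# exposes these through the "VALL-E Model Config" dropdown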
+
 	autoregressive_models = get_autoregressive_models()
 	diffusion_models = get_diffusion_models()
 	tokenizer_jsons = get_tokenizer_jsons()
@@ -337,11 +339,11 @@ def setup_gradio():
 			with gr.Column():
 				GENERATE_SETTINGS["delimiter"] = gr.Textbox(lines=1, label="Line Delimiter", placeholder="\\n")
 
-				GENERATE_SETTINGS["emotion"] = gr.Radio( ["Happy", "Sad", "Angry", "Disgusted", "Arrogant", "Custom", "None"], value="None", label="Emotion", type="value", interactive=True )
+				GENERATE_SETTINGS["emotion"] = gr.Radio( ["Happy", "Sad", "Angry", "Disgusted", "Arrogant", "Custom", "None"], value="None", label="Emotion", type="value", interactive=True, visible=args.tts_backend=="tortoise" )
 				GENERATE_SETTINGS["prompt"] = gr.Textbox(lines=1, label="Custom Emotion", visible=False)
 				GENERATE_SETTINGS["voice"] = gr.Dropdown(choices=voice_list_with_defaults, label="Voice", type="value", value=voice_list_with_defaults[0]) # it'd be very cash money if gradio was able to default to the first value in the list without this shit
 				GENERATE_SETTINGS["mic_audio"] = gr.Audio( label="Microphone Source", source="microphone", type="filepath", visible=False )
-				GENERATE_SETTINGS["voice_latents_chunks"] = gr.Number(label="Voice Chunks", precision=0, value=0)
+				GENERATE_SETTINGS["voice_latents_chunks"] = gr.Number(label="Voice Chunks", precision=0, value=0, visible=args.tts_backend=="tortoise")
 				with gr.Row():
 					refresh_voices = gr.Button(value="Refresh Voice List")
 					recompute_voice_latents = gr.Button(value="(Re)Compute Voice Latents")
@@ -357,17 +359,17 @@ def setup_gradio():
 						outputs=GENERATE_SETTINGS["mic_audio"],
 					)
 			with gr.Column():
-				GENERATE_SETTINGS["candidates"] = gr.Slider(value=1, minimum=1, maximum=6, step=1, label="Candidates")
+				GENERATE_SETTINGS["candidates"] = gr.Slider(value=1, minimum=1, maximum=6, step=1, label="Candidates", visible=args.tts_backend=="tortoise")
 				GENERATE_SETTINGS["seed"] = gr.Number(value=0, precision=0, label="Seed")
 
 				preset = gr.Radio( ["Ultra Fast", "Fast", "Standard", "High Quality"], label="Preset", type="value", value="Ultra Fast" )
 
 				GENERATE_SETTINGS["num_autoregressive_samples"] = gr.Slider(value=16, minimum=2, maximum=512, step=1, label="Samples")
-				GENERATE_SETTINGS["diffusion_iterations"] = gr.Slider(value=30, minimum=0, maximum=512, step=1, label="Iterations")
+				GENERATE_SETTINGS["diffusion_iterations"] = gr.Slider(value=30, minimum=0, maximum=512, step=1, label="Iterations", visible=args.tts_backend=="tortoise")
 
 				GENERATE_SETTINGS["temperature"] = gr.Slider(value=0.2, minimum=0, maximum=1, step=0.1, label="Temperature")
 
-				show_experimental_settings = gr.Checkbox(label="Show Experimental Settings")
+				show_experimental_settings = gr.Checkbox(label="Show Experimental Settings", visible=args.tts_backend=="tortoise")
 				reset_generate_settings_button = gr.Button(value="Reset to Default")
 			with gr.Column(visible=False) as col:
 				experimental_column = col
@@ -606,10 +608,12 @@ def setup_gradio():
 					EXEC_SETTINGS['device_override'] = gr.Textbox(label="Device Override", value=args.device_override)
 
 					EXEC_SETTINGS['results_folder'] = gr.Textbox(label="Results Folder", value=args.results_folder)
-
-				with gr.Column():
 					# EXEC_SETTINGS['tts_backend'] = gr.Dropdown(TTSES, label="TTS Backend", value=args.tts_backend if args.tts_backend else TTSES[0])
+
+				with gr.Column(visible=args.tts_backend=="vall-e"):
+					EXEC_SETTINGS['valle_model'] = gr.Dropdown(choices=valle_models, label="VALL-E Model Config", value=args.valle_model if args.valle_model else valle_models[0])
+				with gr.Column(visible=args.tts_backend=="tortoise"):
 					EXEC_SETTINGS['autoregressive_model'] = gr.Dropdown(choices=["auto"] + autoregressive_models, label="Autoregressive Model", value=args.autoregressive_model if args.autoregressive_model else "auto")
 					EXEC_SETTINGS['diffusion_model'] = gr.Dropdown(choices=diffusion_models, label="Diffusion Model", value=args.diffusion_model if args.diffusion_model else diffusion_models[0])
 					EXEC_SETTINGS['vocoder_model'] = gr.Dropdown(VOCODERS, label="Vocoder", value=args.vocoder_model if args.vocoder_model else VOCODERS[-1])