From 811539b20adfe6d85d2bc3e6728d55fd2427aae0 Mon Sep 17 00:00:00 2001 From: mrq Date: Fri, 10 Feb 2023 16:47:57 +0000 Subject: [PATCH] Added the remaining input settings --- webui.py | 74 +++++++++++++++++++++++++++++++++++++++++++++----------- 1 file changed, 60 insertions(+), 14 deletions(-) diff --git a/webui.py b/webui.py index e868528..c4651e4 100755 --- a/webui.py +++ b/webui.py @@ -24,7 +24,29 @@ args = None webui = None tts = None -def generate(text, delimiter, emotion, prompt, voice, mic_audio, seed, candidates, num_autoregressive_samples, diffusion_iterations, temperature, diffusion_sampler, breathing_room, cvvp_weight, experimentals, progress=gr.Progress(track_tqdm=True)): +def generate( + text, + delimiter, + emotion, + prompt, + voice, + mic_audio, + seed, + candidates, + num_autoregressive_samples, + diffusion_iterations, + temperature, + diffusion_sampler, + breathing_room, + cvvp_weight, + top_p, + diffusion_temperature, + length_penalty, + repetition_penalty, + cond_free_k, + experimental_checkboxes, + progress=gr.Progress(track_tqdm=True) +): try: tts except NameError: @@ -67,9 +89,13 @@ def generate(text, delimiter, emotion, prompt, voice, mic_audio, seed, candidate settings = { - 'temperature': temperature, 'length_penalty': 1.0, 'repetition_penalty': 2.0, - 'top_p': .8, - 'cond_free_k': 2.0, 'diffusion_temperature': 1.0, + 'temperature': float(temperature), + + 'top_p': float(top_p), + 'diffusion_temperature': float(diffusion_temperature), + 'length_penalty': float(length_penalty), + 'repetition_penalty': float(repetition_penalty), + 'cond_free_k': float(cond_free_k), 'num_autoregressive_samples': num_autoregressive_samples, 'sample_batch_size': args.sample_batch_size, @@ -83,8 +109,8 @@ def generate(text, delimiter, emotion, prompt, voice, mic_audio, seed, candidate 'diffusion_sampler': diffusion_sampler, 'breathing_room': breathing_room, 'progress': progress, - 'half_p': "Half Precision" in experimentals, - 'cond_free': "Conditioning-Free" in experimentals, + 'half_p': "Half Precision" in experimental_checkboxes, + 'cond_free': "Conditioning-Free" in experimental_checkboxes, 'cvvp_amount': cvvp_weight, } @@ -196,7 +222,12 @@ def generate(text, delimiter, emotion, prompt, voice, mic_audio, seed, candidate 'diffusion_sampler': diffusion_sampler, 'breathing_room': breathing_room, 'cvvp_weight': cvvp_weight, - 'experimentals': experimentals, + 'top_p': top_p, + 'diffusion_temperature': diffusion_temperature, + 'length_penalty': length_penalty, + 'repetition_penalty': repetition_penalty, + 'cond_free_k': cond_free_k, + 'experimentals': experimental_checkboxes, 'time': time.time()-start_time, } @@ -294,10 +325,15 @@ def import_generate_settings(file="./config/generate.json"): None if 'candidates' not in settings else settings['candidates'], None if 'num_autoregressive_samples' not in settings else settings['num_autoregressive_samples'], None if 'diffusion_iterations' not in settings else settings['diffusion_iterations'], - None if 'temperature' not in settings else settings['temperature'], - None if 'diffusion_sampler' not in settings else settings['diffusion_sampler'], - None if 'breathing_room' not in settings else settings['breathing_room'], - None if 'cvvp_weight' not in settings else settings['cvvp_weight'], + 0.8 if 'temperature' not in settings else settings['temperature'], + "DDIM" if 'diffusion_sampler' not in settings else settings['diffusion_sampler'], + 8.0 if 'breathing_room' not in settings else settings['breathing_room'], + 0.0 if 'cvvp_weight' not in settings else settings['cvvp_weight'], + 0.8 if 'top_p' not in settings else settings['top_p'], + 1.0 if 'diffusion_temperature' not in settings else settings['diffusion_temperature'], + 1.0 if 'length_penalty' not in settings else settings['length_penalty'], + 2.0 if 'repetition_penalty' not in settings else settings['repetition_penalty'], + 2.0 if 'cond_free_k' not in settings else settings['cond_free_k'], None if 'experimentals' not in settings else settings['experimentals'], ) @@ -568,6 +604,7 @@ def setup_gradio(): gr.Checkbox(label="Slimmer Computed Latents", value=args.latents_lean_and_mean), ] gr.Button(value="Check for Updates").click(check_for_updates) + gr.Button(value="Reload TTS").click(reload_tts) with gr.Column(): exec_inputs = exec_inputs + [ gr.Number(label="Voice Latents Max Chunk Size", precision=0, value=args.cond_latent_max_chunk_size), @@ -583,10 +620,14 @@ def setup_gradio(): inputs=exec_inputs ) with gr.Column(): - experimentals = gr.CheckboxGroup(["Half Precision", "Conditioning-Free"], value=["Conditioning-Free"], label="Experimental Flags") + experimental_checkboxes = gr.CheckboxGroup(["Half Precision", "Conditioning-Free"], value=["Conditioning-Free"], label="Experimental Flags") cvvp_weight = gr.Slider(value=0, minimum=0, maximum=1, label="CVVP Weight") + top_p = gr.Slider(value=0.8, minimum=0, maximum=2, label="Top P") + diffusion_temperature = gr.Slider(value=1.0, minimum=0, maximum=2, label="Diffusion Temperature") + length_penalty = gr.Slider(value=1.0, minimum=0, maximum=8, label="Length Penalty") + repetition_penalty = gr.Slider(value=2.0, minimum=0, maximum=8, label="Repetition Penalty") + cond_free_k = gr.Slider(value=2.0, minimum=0, maximum=8, label="Conditioning-Free K") - gr.Button(value="Reload TTS").click(reload_tts) input_settings = [ text, @@ -603,7 +644,12 @@ def setup_gradio(): diffusion_sampler, breathing_room, cvvp_weight, - experimentals, + top_p, + diffusion_temperature, + length_penalty, + repetition_penalty, + cond_free_k, + experimental_checkboxes, ] submit_event = submit.click(generate,