sampler and cond_free selectable in webUI, re-enabled cond_free as default (somehow it's working again)

This commit is contained in:
mrq 2024-06-19 17:12:28 -05:00
parent 73f271fb8a
commit 96b74f38ef
3 changed files with 11 additions and 15 deletions

View File

@ -23,6 +23,7 @@ def main():
parser.add_argument("--beam-width", type=int, default=0) parser.add_argument("--beam-width", type=int, default=0)
parser.add_argument("--diffusion-sampler", type=str, default="ddim") parser.add_argument("--diffusion-sampler", type=str, default="ddim")
parser.add_argument("--cond-free", action="store_true")
parser.add_argument("--yaml", type=Path, default=None) parser.add_argument("--yaml", type=Path, default=None)
parser.add_argument("--device", type=str, default=None) parser.add_argument("--device", type=str, default=None)
@ -59,7 +60,8 @@ def main():
length_penalty=args.length_penalty, length_penalty=args.length_penalty,
beam_width=args.beam_width, beam_width=args.beam_width,
diffusion_sampler=args.diffusion_sampler diffusion_sampler=args.diffusion_sampler,
cond_free=args.cond_free,
) )
""" """
language=args.language, language=args.language,

View File

@ -117,6 +117,7 @@ class TTS():
#mirostat_eta=0.1, #mirostat_eta=0.1,
diffusion_sampler="ddim", diffusion_sampler="ddim",
cond_free=True,
out_path=None out_path=None
): ):
@ -129,7 +130,7 @@ class TTS():
diffusion = None diffusion = None
clvp = None clvp = None
vocoder = None vocoder = None
diffuser = get_diffuser(steps=max_diffusion_steps, cond_free=False) diffuser = get_diffuser(steps=max_diffusion_steps, cond_free=cond_free)
autoregressive_latents, diffusion_latents = self.encode_audio( references )["latent"] autoregressive_latents, diffusion_latents = self.encode_audio( references )["latent"]

View File

@ -67,10 +67,10 @@ def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
""" """
if kwargs.pop("dynamic-sampling", False): if kwargs.pop("dynamic-sampling", False):
kwargs['min-ar-temp'] = 0.85 if kwargs['ar-temp'] > 0.85 else 0.0 kwargs['min-ar-temp'] = 0.85 if kwargs['ar-temp'] > 0.85 else 0.0
kwargs['min-nar-temp'] = 0.85 if kwargs['nar-temp'] > 0.85 else 0.0 # should probably disable it for the NAR kwargs['min-diffusion-temp'] = 0.85 if kwargs['diffusion-temp'] > 0.85 else 0.0 # should probably disable it for the NAR
else: else:
kwargs['min-ar-temp'] = -1 kwargs['min-ar-temp'] = -1
kwargs['min-nar-temp'] = -1 kwargs['min-diffusion-temp'] = -1
""" """
parser = argparse.ArgumentParser(allow_abbrev=False) parser = argparse.ArgumentParser(allow_abbrev=False)
@ -81,9 +81,6 @@ def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
parser.add_argument("--max-diffusion-steps", type=int, default=int(kwargs["max-diffusion-steps"])) parser.add_argument("--max-diffusion-steps", type=int, default=int(kwargs["max-diffusion-steps"]))
""" """
parser.add_argument("--language", type=str, default="en") parser.add_argument("--language", type=str, default="en")
parser.add_argument("--input-prompt-length", type=float, default=kwargs["input-prompt-length"])
parser.add_argument("--max-ar-context", type=int, default=int(kwargs["max-seconds-context"]*cfg.dataset.frames_per_second))
parser.add_argument("--max-nar-levels", type=int, default=kwargs["max-nar-levels"])
""" """
parser.add_argument("--ar-temp", type=float, default=kwargs["ar-temp"]) parser.add_argument("--ar-temp", type=float, default=kwargs["ar-temp"])
parser.add_argument("--diffusion-temp", type=float, default=kwargs["diffusion-temp"]) parser.add_argument("--diffusion-temp", type=float, default=kwargs["diffusion-temp"])
@ -97,6 +94,7 @@ def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
parser.add_argument("--length-penalty", type=float, default=kwargs["length-penalty"]) parser.add_argument("--length-penalty", type=float, default=kwargs["length-penalty"])
parser.add_argument("--beam-width", type=int, default=kwargs["beam-width"]) parser.add_argument("--beam-width", type=int, default=kwargs["beam-width"])
parser.add_argument("--diffusion-sampler", type=str, default=kwargs["diffusion-sampler"]) parser.add_argument("--diffusion-sampler", type=str, default=kwargs["diffusion-sampler"])
parser.add_argument("--cond-free", type=str, default=kwargs["cond-free"])
""" """
parser.add_argument("--repetition-penalty-decay", type=float, default=kwargs["repetition-penalty-decay"]) parser.add_argument("--repetition-penalty-decay", type=float, default=kwargs["repetition-penalty-decay"])
parser.add_argument("--mirostat-tau", type=float, default=kwargs["mirostat-tau"]) parser.add_argument("--mirostat-tau", type=float, default=kwargs["mirostat-tau"])
@ -215,19 +213,14 @@ with ui:
layout["inference"]["outputs"]["output"] = gr.Audio(label="Output") layout["inference"]["outputs"]["output"] = gr.Audio(label="Output")
layout["inference"]["buttons"]["inference"] = gr.Button(value="Inference") layout["inference"]["buttons"]["inference"] = gr.Button(value="Inference")
with gr.Column(scale=7): with gr.Column(scale=7):
"""
with gr.Row():
layout["inference"]["inputs"]["max-seconds"] = gr.Slider(value=12, minimum=1, maximum=32, step=0.1, label="Maximum Seconds", info="Limits how many steps to perform in the AR pass.")
layout["inference"]["inputs"]["max-nar-levels"] = gr.Slider(value=7, minimum=0, maximum=7, step=1, label="Max NAR Levels", info="Limits how many steps to perform in the NAR pass.")
layout["inference"]["inputs"]["input-prompt-length"] = gr.Slider(value=3.0, minimum=0.0, maximum=12.0, step=0.05, label="Input Prompt Trim Length", info="Trims the input prompt down to X seconds. Set 0 to disable.")
layout["inference"]["inputs"]["max-seconds-context"] = gr.Slider(value=0.0, minimum=0.0, maximum=12.0, step=0.05, label="Context Length", info="Amount of generated audio to keep in the context during inference, in seconds. Set 0 to disable.")
"""
with gr.Row(): with gr.Row():
layout["inference"]["inputs"]["max-ar-steps"] = gr.Slider(value=500, minimum=16, maximum=1200, step=1, label="Maximum AR Steps", info="Limits how many steps to perform in the AR pass.") layout["inference"]["inputs"]["max-ar-steps"] = gr.Slider(value=500, minimum=16, maximum=1200, step=1, label="Maximum AR Steps", info="Limits how many steps to perform in the AR pass.")
layout["inference"]["inputs"]["max-diffusion-steps"] = gr.Slider(value=80, minimum=16, maximum=500, step=1, label="Maximum Diffusion Steps", info="Limits how many steps to perform in the Diffusion pass.") layout["inference"]["inputs"]["max-diffusion-steps"] = gr.Slider(value=80, minimum=16, maximum=500, step=1, label="Maximum Diffusion Steps", info="Limits how many steps to perform in the Diffusion pass.")
layout["inference"]["inputs"]["diffusion-sampler"] = gr.Radio( ["P", "DDIM"], value="DDIM", label="Diffusion Samplers", type="value", info="Sampler to use during the diffusion pass." )
layout["inference"]["inputs"]["cond-free"] = gr.Checkbox(label="Cond. Free", value=True, info="Condition Free diffusion")
with gr.Row(): with gr.Row():
layout["inference"]["inputs"]["ar-temp"] = gr.Slider(value=0.95, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR. (0 to greedy sample)") layout["inference"]["inputs"]["ar-temp"] = gr.Slider(value=0.95, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR. (0 to greedy sample)")
layout["inference"]["inputs"]["diffusion-temp"] = gr.Slider(value=0.01, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (NAR)", info="Modifies the randomness from the samples in the NAR. (0 to greedy sample)") layout["inference"]["inputs"]["diffusion-temp"] = gr.Slider(value=0.01, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (Diffusion)", info="Modifies the initial noise during the diffusion pass.")
""" """
with gr.Row(): with gr.Row():
layout["inference"]["inputs"]["dynamic-sampling"] = gr.Checkbox(label="Dynamic Temperature", info="Dynamically adjusts the temperature based on the highest confident predicted token per sampling step.") layout["inference"]["inputs"]["dynamic-sampling"] = gr.Checkbox(label="Dynamic Temperature", info="Dynamically adjusts the temperature based on the highest confident predicted token per sampling step.")