From 92e6bff6dc96fd8679c58e487f7742a02a4eef8f Mon Sep 17 00:00:00 2001 From: mrq Date: Wed, 23 Oct 2024 00:03:35 -0500 Subject: [PATCH] actually ar temp 0.5 with rep pen 1.125 seems to have the benefits of better outputs without it degrading some of the time but not all the time --- vall_e/__main__.py | 2 +- vall_e/models/ar_nar.py | 4 ++-- vall_e/webui.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/vall_e/__main__.py b/vall_e/__main__.py index 362de39..c9edacc 100755 --- a/vall_e/__main__.py +++ b/vall_e/__main__.py @@ -21,7 +21,7 @@ def main(): parser.add_argument("--max-ar-steps", type=int, default=12 * cfg.dataset.frames_per_second) parser.add_argument("--max-nar-levels", type=int, default=7) - parser.add_argument("--ar-temp", type=float, default=0.0) + parser.add_argument("--ar-temp", type=float, default=0.5) parser.add_argument("--nar-temp", type=float, default=0.0) parser.add_argument("--min-ar-temp", type=float, default=-1.0) parser.add_argument("--min-nar-temp", type=float, default=-1.0) diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py index 7e81818..9c2c51a 100644 --- a/vall_e/models/ar_nar.py +++ b/vall_e/models/ar_nar.py @@ -368,7 +368,7 @@ class AR_NAR(Base): mirostat = sampled.scores elif sampling_beam_width > 0: # expand tuple - scores = sampled.scores + s = sampled.scores # first step, expand batch if batch_size == 1: batch_size = sampling_beam_width @@ -379,7 +379,7 @@ class AR_NAR(Base): start_slice = start_slice * sampling_beam_width stopped = torch.zeros(batch_size, device=device).bool() - scores = [ scores[i] + score for i, score in enumerate(scores) ] + scores = [ scores[i] + score for i, score in enumerate(s) ] # append tokens for i, ri in enumerate(r): diff --git a/vall_e/webui.py b/vall_e/webui.py index 2bde04b..c8b71c1 100644 --- a/vall_e/webui.py +++ b/vall_e/webui.py @@ -351,7 +351,7 @@ with ui: layout["inference_tts"]["inputs"]["max-nar-levels"] = gr.Slider(value=7, minimum=0, maximum=7, step=1, label="Max NAR Levels", info="Limits how many steps to perform in the NAR pass.") layout["inference_tts"]["inputs"]["input-prompt-length"] = gr.Slider(value=5.0, minimum=0.0, maximum=12.0, step=0.05, label="Input Prompt Repeat/Trim Length", info="Repeats and trims the input prompt down to X seconds. Set 0 to disable.") with gr.Row(): - layout["inference_tts"]["inputs"]["ar-temp"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR. (0 to greedy* sample)") + layout["inference_tts"]["inputs"]["ar-temp"] = gr.Slider(value=0.5, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR. (0 to greedy* sample)") layout["inference_tts"]["inputs"]["nar-temp"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (NAR)", info="Modifies the randomness from the samples in the NAR. (0 to greedy sample)") with gr.Row(): if cfg.experimental: