From e727b6e5c1a06bca55d8dc5d8faa78449a18e309 Mon Sep 17 00:00:00 2001 From: mrq Date: Tue, 10 Oct 2023 17:02:33 -0500 Subject: [PATCH] changed dynamic temperature trigger to be a min-(n)ar-temp value between [0,(n)ar-temp), flags to set min temp, checkbox in web UI to request it --- vall_e/__main__.py | 3 +++ vall_e/inference.py | 4 ++++ vall_e/models/ar.py | 2 ++ vall_e/models/ar_nar.py | 7 +++++-- vall_e/models/base.py | 10 +++++++--- vall_e/models/nar.py | 2 ++ vall_e/webui.py | 10 ++++++++++ 7 files changed, 33 insertions(+), 5 deletions(-) diff --git a/vall_e/__main__.py b/vall_e/__main__.py index a900e18..b3a2875 100755 --- a/vall_e/__main__.py +++ b/vall_e/__main__.py @@ -20,6 +20,8 @@ def main(): parser.add_argument("--ar-temp", type=float, default=1.0) parser.add_argument("--nar-temp", type=float, default=1.0) + parser.add_argument("--min-ar-temp", type=float, default=-1.0) + parser.add_argument("--min-nar-temp", type=float, default=-1.0) parser.add_argument("--input-prompt-length", type=float, default=3.0) parser.add_argument("--top-p", type=float, default=1.0) @@ -45,6 +47,7 @@ def main(): input_prompt_length=args.input_prompt_length, max_ar_steps=args.max_ar_steps, max_nar_levels=args.max_nar_levels, ar_temp=args.ar_temp, nar_temp=args.nar_temp, + min_ar_temp=args.min_ar_temp, min_nar_temp=args.min_nar_temp, top_p=args.top_p, top_k=args.top_k, repetition_penalty=args.repetition_penalty, repetition_penalty_decay=args.repetition_penalty_decay, length_penalty=args.length_penalty, diff --git a/vall_e/inference.py b/vall_e/inference.py index 34df7e5..e3701d2 100755 --- a/vall_e/inference.py +++ b/vall_e/inference.py @@ -152,6 +152,8 @@ class TTS(): input_prompt_length=0.0, ar_temp=0.95, nar_temp=0.5, + min_ar_temp=0.95, + min_nar_temp=0.5, top_p=1.0, top_k=0, repetition_penalty=1.0, @@ -175,6 +177,7 @@ class TTS(): resps_list = self.ar( text_list=[phns], proms_list=[prom], max_steps=max_ar_steps, sampling_temperature=ar_temp, + sampling_min_temperature=min_ar_temp, sampling_top_p=top_p, sampling_top_k=top_k, sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay, sampling_length_penalty=length_penalty, @@ -187,6 +190,7 @@ class TTS(): text_list=[phns], proms_list=[prom], resps_list=resps_list, max_levels=max_nar_levels, sampling_temperature=nar_temp, + sampling_min_temperature=min_nar_temp, sampling_top_p=top_p, sampling_top_k=top_k, sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay, ) diff --git a/vall_e/models/ar.py b/vall_e/models/ar.py index 79122df..fd6896b 100755 --- a/vall_e/models/ar.py +++ b/vall_e/models/ar.py @@ -105,6 +105,7 @@ class AR(Base): max_steps: int = 1000, sampling_temperature: float = 1.0, + sampling_min_temperature: float = -1.0, sampling_top_k: int = -100, sampling_top_p: float = 1.0, sampling_repetition_penalty: float = 1.0, @@ -162,6 +163,7 @@ class AR(Base): resps_list=resps_list, temperature=sampling_temperature, + min_temperature=sampling_min_temperature, top_p=sampling_top_p, top_k=sampling_top_k, repetition_penalty=sampling_repetition_penalty, diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py index 6731261..70889ef 100644 --- a/vall_e/models/ar_nar.py +++ b/vall_e/models/ar_nar.py @@ -88,7 +88,8 @@ class AR_NAR(Base): resps_list: list[Tensor] | None = None, max_steps: int = 1000, max_levels: int = 7, - sampling_temperature: float = 0.0, + sampling_temperature: float = 1.0, + sampling_min_temperature: float = -1.0, sampling_top_k: int = -100, sampling_top_p: float = 1.0, sampling_repetition_penalty: float = 1.0, @@ -154,6 +155,7 @@ class AR_NAR(Base): quant_levels=quant_levels, temperature=sampling_temperature, + min_temperature=sampling_min_temperature, top_p=sampling_top_p, top_k=sampling_top_k, repetition_penalty=sampling_repetition_penalty, @@ -198,6 +200,7 @@ class AR_NAR(Base): resps_list=resps_list, temperature=sampling_temperature, + min_temperature=sampling_min_temperature, top_p=sampling_top_p, top_k=sampling_top_k, repetition_penalty=sampling_repetition_penalty, @@ -320,7 +323,7 @@ def example_usage(): @torch.inference_mode() def sample( name, steps=600 ): engine.eval() - resps_list = engine(text_list, proms_list, max_steps=steps, sampling_temperature=0.95, sampling_beam_width=16 ) + resps_list = engine(text_list, proms_list, max_steps=steps, sampling_temperature=0.95 ) for i, o in enumerate(resps_list): _ = decode_to_file(o, f"data/ar.{i}.{name}.wav", device=device) diff --git a/vall_e/models/base.py b/vall_e/models/base.py index c93eacf..ec769f0 100755 --- a/vall_e/models/base.py +++ b/vall_e/models/base.py @@ -540,6 +540,7 @@ class Base(nn.Module): quant_levels: Tensor | None = None, temperature: float = 1.0, + min_temperature: float = -1.0, top_k: int = -100, top_p: float = 1.0, @@ -552,6 +553,8 @@ class Base(nn.Module): mirostat: list[dict] | None = None, ): + if min_temperature < 0: + min_temperature = temperature # (NAR) return the entire generated response if quant_levels is not None: logits = [ logit[-l:] for logit, l in zip(logits, map(len, resps_list)) ] @@ -576,9 +579,10 @@ class Base(nn.Module): if top_k > 0 or top_p < 1.0: logits = [ top_k_top_p_filtering(logit, top_k=top_k, top_p=top_p) for logit in logits ] - # our dynamic temperature threshold is considered to be anything over 1.0. - if temperature > 1.0: - logits = [ dynamic_temperature(logit, temperature=temperature) for logit in logits ] + # trigger dynamic temperature sampling if the minimum temperature is not the same as the sampling temperature + # epsilon float comparison because I don't trust Python + if abs(temperature - min_temperature) >= 0.001: + logits = [ dynamic_temperature(logit, temperature=temperature, min_temperature=min_temperature) for logit in logits ] else: logits = [ logit / temperature for logit in logits ] diff --git a/vall_e/models/nar.py b/vall_e/models/nar.py index 100bdfa..c34ddb9 100755 --- a/vall_e/models/nar.py +++ b/vall_e/models/nar.py @@ -76,6 +76,7 @@ class NAR(Base): resps_list: list[Tensor], max_levels: int = 0, sampling_temperature: float = 0.2, + sampling_min_temperature: float = -1.0, sampling_top_k: int = -100, sampling_top_p: float = 1.0, sampling_repetition_penalty: float = 1.0, @@ -147,6 +148,7 @@ class NAR(Base): quant_levels=quant_levels, temperature=sampling_temperature, + min_temperature=sampling_min_temperature, top_p=sampling_top_p, top_k=sampling_top_k, repetition_penalty=sampling_repetition_penalty, diff --git a/vall_e/webui.py b/vall_e/webui.py index 6a23482..f65e027 100644 --- a/vall_e/webui.py +++ b/vall_e/webui.py @@ -64,6 +64,10 @@ def init_tts(restart=False): @gradio_wrapper(inputs=layout["inference"]["inputs"].keys()) def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ): + if kwargs.pop("dynamic-sampling", False): + kwargs['min-ar-temp'] = 0.85 if kwargs['ar-temp'] > 0.85 else 0.0 + kwargs['min-nar-temp'] = 0.2 if kwargs['nar-temp'] > 0.2 else 0.0 + parser = argparse.ArgumentParser(allow_abbrev=False) # I'm very sure I can procedurally generate this list parser.add_argument("--text", type=str, default=kwargs["text"]) @@ -73,6 +77,8 @@ def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ): parser.add_argument("--max-nar-levels", type=int, default=kwargs["max-nar-levels"]) parser.add_argument("--ar-temp", type=float, default=kwargs["ar-temp"]) parser.add_argument("--nar-temp", type=float, default=kwargs["nar-temp"]) + parser.add_argument("--min-ar-temp", type=float, default=kwargs["min-ar-temp"]) + parser.add_argument("--min-nar-temp", type=float, default=kwargs["min-nar-temp"]) parser.add_argument("--top-p", type=float, default=kwargs["top-p"]) parser.add_argument("--top-k", type=int, default=kwargs["top-k"]) parser.add_argument("--repetition-penalty", type=float, default=kwargs["repetition-penalty"]) @@ -99,6 +105,8 @@ def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ): input_prompt_length=args.input_prompt_length, ar_temp=args.ar_temp, nar_temp=args.nar_temp, + min_ar_temp=args.min_ar_temp, + min_nar_temp=args.min_nar_temp, top_p=args.top_p, top_k=args.top_k, repetition_penalty=args.repetition_penalty, @@ -192,6 +200,8 @@ with ui: with gr.Row(): layout["inference"]["inputs"]["ar-temp"] = gr.Slider(value=0.95, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR.") layout["inference"]["inputs"]["nar-temp"] = gr.Slider(value=0.25, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (NAR)", info="Modifies the randomness from the samples in the NAR.") + with gr.Row(): + layout["inference"]["inputs"]["dynamic-sampling"] = gr.Checkbox(label="Dynamic Temperature", info="Dynamically adjusts the temperature based on the highest confident predicted token per sampling step.") with gr.Row(): layout["inference"]["inputs"]["top-p"] = gr.Slider(value=1.0, minimum=0.0, maximum=1.0, step=0.05, label="Top P", info=r"Limits the samples that are outside the top P% of probabilities.")