changed dynamic temperature trigger to be a min-(n)ar-temp value between [0,(n)ar-temp), flags to set min temp, checkbox in web UI to request it

This commit is contained in:
mrq 2023-10-10 17:02:33 -05:00
parent ec25f56bd9
commit e727b6e5c1
7 changed files with 33 additions and 5 deletions

View File

@ -20,6 +20,8 @@ def main():
parser.add_argument("--ar-temp", type=float, default=1.0)
parser.add_argument("--nar-temp", type=float, default=1.0)
parser.add_argument("--min-ar-temp", type=float, default=-1.0)
parser.add_argument("--min-nar-temp", type=float, default=-1.0)
parser.add_argument("--input-prompt-length", type=float, default=3.0)
parser.add_argument("--top-p", type=float, default=1.0)
@ -45,6 +47,7 @@ def main():
input_prompt_length=args.input_prompt_length,
max_ar_steps=args.max_ar_steps, max_nar_levels=args.max_nar_levels,
ar_temp=args.ar_temp, nar_temp=args.nar_temp,
min_ar_temp=args.min_ar_temp, min_nar_temp=args.min_nar_temp,
top_p=args.top_p, top_k=args.top_k,
repetition_penalty=args.repetition_penalty, repetition_penalty_decay=args.repetition_penalty_decay,
length_penalty=args.length_penalty,

View File

@ -152,6 +152,8 @@ class TTS():
input_prompt_length=0.0,
ar_temp=0.95,
nar_temp=0.5,
min_ar_temp=0.95,
min_nar_temp=0.5,
top_p=1.0,
top_k=0,
repetition_penalty=1.0,
@ -175,6 +177,7 @@ class TTS():
resps_list = self.ar(
text_list=[phns], proms_list=[prom], max_steps=max_ar_steps,
sampling_temperature=ar_temp,
sampling_min_temperature=min_ar_temp,
sampling_top_p=top_p, sampling_top_k=top_k,
sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay,
sampling_length_penalty=length_penalty,
@ -187,6 +190,7 @@ class TTS():
text_list=[phns], proms_list=[prom], resps_list=resps_list,
max_levels=max_nar_levels,
sampling_temperature=nar_temp,
sampling_min_temperature=min_nar_temp,
sampling_top_p=top_p, sampling_top_k=top_k,
sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay,
)

View File

@ -105,6 +105,7 @@ class AR(Base):
max_steps: int = 1000,
sampling_temperature: float = 1.0,
sampling_min_temperature: float = -1.0,
sampling_top_k: int = -100,
sampling_top_p: float = 1.0,
sampling_repetition_penalty: float = 1.0,
@ -162,6 +163,7 @@ class AR(Base):
resps_list=resps_list,
temperature=sampling_temperature,
min_temperature=sampling_min_temperature,
top_p=sampling_top_p,
top_k=sampling_top_k,
repetition_penalty=sampling_repetition_penalty,

View File

@ -88,7 +88,8 @@ class AR_NAR(Base):
resps_list: list[Tensor] | None = None,
max_steps: int = 1000,
max_levels: int = 7,
sampling_temperature: float = 0.0,
sampling_temperature: float = 1.0,
sampling_min_temperature: float = -1.0,
sampling_top_k: int = -100,
sampling_top_p: float = 1.0,
sampling_repetition_penalty: float = 1.0,
@ -154,6 +155,7 @@ class AR_NAR(Base):
quant_levels=quant_levels,
temperature=sampling_temperature,
min_temperature=sampling_min_temperature,
top_p=sampling_top_p,
top_k=sampling_top_k,
repetition_penalty=sampling_repetition_penalty,
@ -198,6 +200,7 @@ class AR_NAR(Base):
resps_list=resps_list,
temperature=sampling_temperature,
min_temperature=sampling_min_temperature,
top_p=sampling_top_p,
top_k=sampling_top_k,
repetition_penalty=sampling_repetition_penalty,
@ -320,7 +323,7 @@ def example_usage():
@torch.inference_mode()
def sample( name, steps=600 ):
engine.eval()
resps_list = engine(text_list, proms_list, max_steps=steps, sampling_temperature=0.95, sampling_beam_width=16 )
resps_list = engine(text_list, proms_list, max_steps=steps, sampling_temperature=0.95 )
for i, o in enumerate(resps_list):
_ = decode_to_file(o, f"data/ar.{i}.{name}.wav", device=device)

View File

@ -540,6 +540,7 @@ class Base(nn.Module):
quant_levels: Tensor | None = None,
temperature: float = 1.0,
min_temperature: float = -1.0,
top_k: int = -100,
top_p: float = 1.0,
@ -552,6 +553,8 @@ class Base(nn.Module):
mirostat: list[dict] | None = None,
):
if min_temperature < 0:
min_temperature = temperature
# (NAR) return the entire generated response
if quant_levels is not None:
logits = [ logit[-l:] for logit, l in zip(logits, map(len, resps_list)) ]
@ -576,9 +579,10 @@ class Base(nn.Module):
if top_k > 0 or top_p < 1.0:
logits = [ top_k_top_p_filtering(logit, top_k=top_k, top_p=top_p) for logit in logits ]
# our dynamic temperature threshold is considered to be anything over 1.0.
if temperature > 1.0:
logits = [ dynamic_temperature(logit, temperature=temperature) for logit in logits ]
# trigger dynamic temperature sampling if the minimum temperature is not the same as the sampling temperature
# epsilon float comparison because I don't trust Python
if abs(temperature - min_temperature) >= 0.001:
logits = [ dynamic_temperature(logit, temperature=temperature, min_temperature=min_temperature) for logit in logits ]
else:
logits = [ logit / temperature for logit in logits ]

View File

@ -76,6 +76,7 @@ class NAR(Base):
resps_list: list[Tensor],
max_levels: int = 0,
sampling_temperature: float = 0.2,
sampling_min_temperature: float = -1.0,
sampling_top_k: int = -100,
sampling_top_p: float = 1.0,
sampling_repetition_penalty: float = 1.0,
@ -147,6 +148,7 @@ class NAR(Base):
quant_levels=quant_levels,
temperature=sampling_temperature,
min_temperature=sampling_min_temperature,
top_p=sampling_top_p,
top_k=sampling_top_k,
repetition_penalty=sampling_repetition_penalty,

View File

@ -64,6 +64,10 @@ def init_tts(restart=False):
@gradio_wrapper(inputs=layout["inference"]["inputs"].keys())
def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
if kwargs.pop("dynamic-sampling", False):
kwargs['min-ar-temp'] = 0.85 if kwargs['ar-temp'] > 0.85 else 0.0
kwargs['min-nar-temp'] = 0.2 if kwargs['nar-temp'] > 0.2 else 0.0
parser = argparse.ArgumentParser(allow_abbrev=False)
# I'm very sure I can procedurally generate this list
parser.add_argument("--text", type=str, default=kwargs["text"])
@ -73,6 +77,8 @@ def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
parser.add_argument("--max-nar-levels", type=int, default=kwargs["max-nar-levels"])
parser.add_argument("--ar-temp", type=float, default=kwargs["ar-temp"])
parser.add_argument("--nar-temp", type=float, default=kwargs["nar-temp"])
parser.add_argument("--min-ar-temp", type=float, default=kwargs["min-ar-temp"])
parser.add_argument("--min-nar-temp", type=float, default=kwargs["min-nar-temp"])
parser.add_argument("--top-p", type=float, default=kwargs["top-p"])
parser.add_argument("--top-k", type=int, default=kwargs["top-k"])
parser.add_argument("--repetition-penalty", type=float, default=kwargs["repetition-penalty"])
@ -99,6 +105,8 @@ def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
input_prompt_length=args.input_prompt_length,
ar_temp=args.ar_temp,
nar_temp=args.nar_temp,
min_ar_temp=args.min_ar_temp,
min_nar_temp=args.min_nar_temp,
top_p=args.top_p,
top_k=args.top_k,
repetition_penalty=args.repetition_penalty,
@ -192,6 +200,8 @@ with ui:
with gr.Row():
layout["inference"]["inputs"]["ar-temp"] = gr.Slider(value=0.95, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR.")
layout["inference"]["inputs"]["nar-temp"] = gr.Slider(value=0.25, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (NAR)", info="Modifies the randomness from the samples in the NAR.")
with gr.Row():
layout["inference"]["inputs"]["dynamic-sampling"] = gr.Checkbox(label="Dynamic Temperature", info="Dynamically adjusts the temperature based on the highest confident predicted token per sampling step.")
with gr.Row():
layout["inference"]["inputs"]["top-p"] = gr.Slider(value=1.0, minimum=0.0, maximum=1.0, step=0.05, label="Top P", info=r"Limits the samples that are outside the top P% of probabilities.")