Changed the dynamic temperature trigger to use a minimum (n)ar-temp value in the range [0, (n)ar-temp); added CLI flags to set the minimum temperature and a checkbox in the web UI to request it.
This commit is contained in:
parent
ec25f56bd9
commit
e727b6e5c1
|
@ -20,6 +20,8 @@ def main():
|
|||
|
||||
parser.add_argument("--ar-temp", type=float, default=1.0)
|
||||
parser.add_argument("--nar-temp", type=float, default=1.0)
|
||||
parser.add_argument("--min-ar-temp", type=float, default=-1.0)
|
||||
parser.add_argument("--min-nar-temp", type=float, default=-1.0)
|
||||
parser.add_argument("--input-prompt-length", type=float, default=3.0)
|
||||
|
||||
parser.add_argument("--top-p", type=float, default=1.0)
|
||||
|
@ -45,6 +47,7 @@ def main():
|
|||
input_prompt_length=args.input_prompt_length,
|
||||
max_ar_steps=args.max_ar_steps, max_nar_levels=args.max_nar_levels,
|
||||
ar_temp=args.ar_temp, nar_temp=args.nar_temp,
|
||||
min_ar_temp=args.min_ar_temp, min_nar_temp=args.min_nar_temp,
|
||||
top_p=args.top_p, top_k=args.top_k,
|
||||
repetition_penalty=args.repetition_penalty, repetition_penalty_decay=args.repetition_penalty_decay,
|
||||
length_penalty=args.length_penalty,
|
||||
|
|
|
@ -152,6 +152,8 @@ class TTS():
|
|||
input_prompt_length=0.0,
|
||||
ar_temp=0.95,
|
||||
nar_temp=0.5,
|
||||
min_ar_temp=0.95,
|
||||
min_nar_temp=0.5,
|
||||
top_p=1.0,
|
||||
top_k=0,
|
||||
repetition_penalty=1.0,
|
||||
|
@ -175,6 +177,7 @@ class TTS():
|
|||
resps_list = self.ar(
|
||||
text_list=[phns], proms_list=[prom], max_steps=max_ar_steps,
|
||||
sampling_temperature=ar_temp,
|
||||
sampling_min_temperature=min_ar_temp,
|
||||
sampling_top_p=top_p, sampling_top_k=top_k,
|
||||
sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay,
|
||||
sampling_length_penalty=length_penalty,
|
||||
|
@ -187,6 +190,7 @@ class TTS():
|
|||
text_list=[phns], proms_list=[prom], resps_list=resps_list,
|
||||
max_levels=max_nar_levels,
|
||||
sampling_temperature=nar_temp,
|
||||
sampling_min_temperature=min_nar_temp,
|
||||
sampling_top_p=top_p, sampling_top_k=top_k,
|
||||
sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay,
|
||||
)
|
||||
|
|
|
@ -105,6 +105,7 @@ class AR(Base):
|
|||
max_steps: int = 1000,
|
||||
|
||||
sampling_temperature: float = 1.0,
|
||||
sampling_min_temperature: float = -1.0,
|
||||
sampling_top_k: int = -100,
|
||||
sampling_top_p: float = 1.0,
|
||||
sampling_repetition_penalty: float = 1.0,
|
||||
|
@ -162,6 +163,7 @@ class AR(Base):
|
|||
resps_list=resps_list,
|
||||
|
||||
temperature=sampling_temperature,
|
||||
min_temperature=sampling_min_temperature,
|
||||
top_p=sampling_top_p,
|
||||
top_k=sampling_top_k,
|
||||
repetition_penalty=sampling_repetition_penalty,
|
||||
|
|
|
@ -88,7 +88,8 @@ class AR_NAR(Base):
|
|||
resps_list: list[Tensor] | None = None,
|
||||
max_steps: int = 1000,
|
||||
max_levels: int = 7,
|
||||
sampling_temperature: float = 0.0,
|
||||
sampling_temperature: float = 1.0,
|
||||
sampling_min_temperature: float = -1.0,
|
||||
sampling_top_k: int = -100,
|
||||
sampling_top_p: float = 1.0,
|
||||
sampling_repetition_penalty: float = 1.0,
|
||||
|
@ -154,6 +155,7 @@ class AR_NAR(Base):
|
|||
quant_levels=quant_levels,
|
||||
|
||||
temperature=sampling_temperature,
|
||||
min_temperature=sampling_min_temperature,
|
||||
top_p=sampling_top_p,
|
||||
top_k=sampling_top_k,
|
||||
repetition_penalty=sampling_repetition_penalty,
|
||||
|
@ -198,6 +200,7 @@ class AR_NAR(Base):
|
|||
resps_list=resps_list,
|
||||
|
||||
temperature=sampling_temperature,
|
||||
min_temperature=sampling_min_temperature,
|
||||
top_p=sampling_top_p,
|
||||
top_k=sampling_top_k,
|
||||
repetition_penalty=sampling_repetition_penalty,
|
||||
|
@ -320,7 +323,7 @@ def example_usage():
|
|||
@torch.inference_mode()
|
||||
def sample( name, steps=600 ):
|
||||
engine.eval()
|
||||
resps_list = engine(text_list, proms_list, max_steps=steps, sampling_temperature=0.95, sampling_beam_width=16 )
|
||||
resps_list = engine(text_list, proms_list, max_steps=steps, sampling_temperature=0.95 )
|
||||
|
||||
for i, o in enumerate(resps_list):
|
||||
_ = decode_to_file(o, f"data/ar.{i}.{name}.wav", device=device)
|
||||
|
|
|
@ -540,6 +540,7 @@ class Base(nn.Module):
|
|||
quant_levels: Tensor | None = None,
|
||||
|
||||
temperature: float = 1.0,
|
||||
min_temperature: float = -1.0,
|
||||
top_k: int = -100,
|
||||
top_p: float = 1.0,
|
||||
|
||||
|
@ -552,6 +553,8 @@ class Base(nn.Module):
|
|||
|
||||
mirostat: list[dict] | None = None,
|
||||
):
|
||||
if min_temperature < 0:
|
||||
min_temperature = temperature
|
||||
# (NAR) return the entire generated response
|
||||
if quant_levels is not None:
|
||||
logits = [ logit[-l:] for logit, l in zip(logits, map(len, resps_list)) ]
|
||||
|
@ -576,9 +579,10 @@ class Base(nn.Module):
|
|||
if top_k > 0 or top_p < 1.0:
|
||||
logits = [ top_k_top_p_filtering(logit, top_k=top_k, top_p=top_p) for logit in logits ]
|
||||
|
||||
# our dynamic temperature threshold is considered to be anything over 1.0.
|
||||
if temperature > 1.0:
|
||||
logits = [ dynamic_temperature(logit, temperature=temperature) for logit in logits ]
|
||||
# trigger dynamic temperature sampling if the minimum temperature is not the same as the sampling temperature
|
||||
# epsilon float comparison because I don't trust Python
|
||||
if abs(temperature - min_temperature) >= 0.001:
|
||||
logits = [ dynamic_temperature(logit, temperature=temperature, min_temperature=min_temperature) for logit in logits ]
|
||||
else:
|
||||
logits = [ logit / temperature for logit in logits ]
|
||||
|
||||
|
|
|
@ -76,6 +76,7 @@ class NAR(Base):
|
|||
resps_list: list[Tensor],
|
||||
max_levels: int = 0,
|
||||
sampling_temperature: float = 0.2,
|
||||
sampling_min_temperature: float = -1.0,
|
||||
sampling_top_k: int = -100,
|
||||
sampling_top_p: float = 1.0,
|
||||
sampling_repetition_penalty: float = 1.0,
|
||||
|
@ -147,6 +148,7 @@ class NAR(Base):
|
|||
quant_levels=quant_levels,
|
||||
|
||||
temperature=sampling_temperature,
|
||||
min_temperature=sampling_min_temperature,
|
||||
top_p=sampling_top_p,
|
||||
top_k=sampling_top_k,
|
||||
repetition_penalty=sampling_repetition_penalty,
|
||||
|
|
|
@ -64,6 +64,10 @@ def init_tts(restart=False):
|
|||
|
||||
@gradio_wrapper(inputs=layout["inference"]["inputs"].keys())
|
||||
def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
|
||||
if kwargs.pop("dynamic-sampling", False):
|
||||
kwargs['min-ar-temp'] = 0.85 if kwargs['ar-temp'] > 0.85 else 0.0
|
||||
kwargs['min-nar-temp'] = 0.2 if kwargs['nar-temp'] > 0.2 else 0.0
|
||||
|
||||
parser = argparse.ArgumentParser(allow_abbrev=False)
|
||||
# I'm very sure I can procedurally generate this list
|
||||
parser.add_argument("--text", type=str, default=kwargs["text"])
|
||||
|
@ -73,6 +77,8 @@ def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
|
|||
parser.add_argument("--max-nar-levels", type=int, default=kwargs["max-nar-levels"])
|
||||
parser.add_argument("--ar-temp", type=float, default=kwargs["ar-temp"])
|
||||
parser.add_argument("--nar-temp", type=float, default=kwargs["nar-temp"])
|
||||
parser.add_argument("--min-ar-temp", type=float, default=kwargs["min-ar-temp"])
|
||||
parser.add_argument("--min-nar-temp", type=float, default=kwargs["min-nar-temp"])
|
||||
parser.add_argument("--top-p", type=float, default=kwargs["top-p"])
|
||||
parser.add_argument("--top-k", type=int, default=kwargs["top-k"])
|
||||
parser.add_argument("--repetition-penalty", type=float, default=kwargs["repetition-penalty"])
|
||||
|
@ -99,6 +105,8 @@ def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
|
|||
input_prompt_length=args.input_prompt_length,
|
||||
ar_temp=args.ar_temp,
|
||||
nar_temp=args.nar_temp,
|
||||
min_ar_temp=args.min_ar_temp,
|
||||
min_nar_temp=args.min_nar_temp,
|
||||
top_p=args.top_p,
|
||||
top_k=args.top_k,
|
||||
repetition_penalty=args.repetition_penalty,
|
||||
|
@ -192,6 +200,8 @@ with ui:
|
|||
with gr.Row():
|
||||
layout["inference"]["inputs"]["ar-temp"] = gr.Slider(value=0.95, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR.")
|
||||
layout["inference"]["inputs"]["nar-temp"] = gr.Slider(value=0.25, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (NAR)", info="Modifies the randomness from the samples in the NAR.")
|
||||
with gr.Row():
|
||||
layout["inference"]["inputs"]["dynamic-sampling"] = gr.Checkbox(label="Dynamic Temperature", info="Dynamically adjusts the temperature based on the highest confident predicted token per sampling step.")
|
||||
|
||||
with gr.Row():
|
||||
layout["inference"]["inputs"]["top-p"] = gr.Slider(value=1.0, minimum=0.0, maximum=1.0, step=0.05, label="Top P", info=r"Limits the samples that are outside the top P% of probabilities.")
|
||||
|
|
Loading…
Reference in New Issue
Block a user