added min-p (really does not seem useful since it's very sensitive), more tweaks to entropix
parent bef43a0c18
commit d0ab7d755a
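(For context: min-p keeps only the tokens whose probability is at least min_p times that of the single most likely token, so the cutoff tightens as the model gets more confident. That coupling to the top token's probability is plausibly why the knob feels so sensitive. A minimal sketch of the idea in plain PyTorch, independent of the helpers this commit adds; min_p_mask is an illustrative name, not part of the repo:)

    import torch

    def min_p_mask(logits: torch.Tensor, min_p: float) -> torch.Tensor:
        # probability of every candidate token
        probs = torch.softmax(logits, dim=-1)
        # threshold scales with the top token's probability
        threshold = min_p * probs.max(dim=-1, keepdim=True).values
        # everything below the scaled threshold is masked out
        return logits.masked_fill(probs < threshold, -float("inf"))

    logits = torch.tensor([[2.0, 1.5, 0.2, -1.0]])
    print(min_p_mask(logits, min_p=0.5))  # the low-probability tail becomes -inf

(The hunks below thread a --min-p flag through the CLI, web UI, and each sampler path, and retune the entropix sampler.)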
@@ -30,6 +30,7 @@ def main():
 
     parser.add_argument("--top-p", type=float, default=1.0)
     parser.add_argument("--top-k", type=int, default=0)
+    parser.add_argument("--min-p", type=float, default=0.0)
    parser.add_argument("--repetition-penalty", type=float, default=1.0)
     parser.add_argument("--repetition-penalty-decay", type=float, default=0.0)
     parser.add_argument("--length-penalty", type=float, default=0.0)
@@ -62,7 +63,7 @@ def main():
         max_ar_steps=args.max_ar_steps, max_nar_levels=args.max_nar_levels,
         ar_temp=args.ar_temp, nar_temp=args.nar_temp,
         min_ar_temp=args.min_ar_temp, min_nar_temp=args.min_nar_temp,
-        top_p=args.top_p, top_k=args.top_k,
+        top_p=args.top_p, top_k=args.top_k, min_p=args.min_p,
         repetition_penalty=args.repetition_penalty, repetition_penalty_decay=args.repetition_penalty_decay,
         length_penalty=args.length_penalty,
         beam_width=args.beam_width,

@@ -192,6 +192,7 @@ class TTS():
         #
         top_p=1.0,
         top_k=0,
+        min_p=0.0,
         #
         repetition_penalty=1.0,
         repetition_penalty_decay=0.0,
@@ -245,7 +246,7 @@ class TTS():
         text_list=None, proms_list=[resp], lang_list=[lang], resps_list=[resp], max_steps=max_ar_steps,
         sampling_temperature=ar_temp,
         sampling_min_temperature=min_ar_temp,
-        sampling_top_p=top_p, sampling_top_k=top_k,
+        sampling_top_p=top_p, sampling_top_k=top_k, sampling_min_p=min_p,
         sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay,
         sampling_length_penalty=length_penalty,
         sampling_beam_width=beam_width,
@@ -289,7 +290,7 @@ class TTS():
         input_prompt_prefix=input_prompt_prefix,
         sampling_temperature=ar_temp,
         sampling_min_temperature=min_ar_temp,
-        sampling_top_p=top_p, sampling_top_k=top_k,
+        sampling_top_p=top_p, sampling_top_k=top_k, sampling_min_p=min_p,
         sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay,
         sampling_length_penalty=length_penalty,
         sampling_beam_width=beam_width,
@@ -308,7 +309,7 @@ class TTS():
         max_levels=max_nar_levels,
         sampling_temperature=nar_temp,
         sampling_min_temperature=min_nar_temp,
-        sampling_top_p=top_p, sampling_top_k=top_k,
+        sampling_top_p=top_p, sampling_top_k=top_k, sampling_min_p=min_p,
         sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay,
 
         disable_tqdm=not tqdm,
@@ -320,7 +321,7 @@ class TTS():
         max_levels=max_nar_levels,
         sampling_temperature=nar_temp,
         sampling_min_temperature=min_nar_temp,
-        sampling_top_p=top_p, sampling_top_k=top_k,
+        sampling_top_p=top_p, sampling_top_k=top_k, sampling_min_p=min_p,
         sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay,
 
         disable_tqdm=not tqdm,

@@ -47,6 +47,7 @@ class AR(Base):
         sampling_min_temperature: float = -1.0,
         sampling_top_k: int = -100,
         sampling_top_p: float = 1.0,
+        sampling_min_p: float = 0.0,
         sampling_repetition_penalty: float = 1.0,
         sampling_repetition_penalty_decay: float = 0.0,
         sampling_length_penalty: float = 0.0,
@@ -202,6 +203,7 @@ class AR(Base):
         min_temperature=sampling_min_temperature,
         top_p=sampling_top_p,
         top_k=sampling_top_k,
+        min_p=sampling_min_p,
         repetition_penalty=sampling_repetition_penalty,
         repetition_penalty_decay=sampling_repetition_penalty_decay,
         length_penalty=sampling_length_penalty,

@@ -54,6 +54,7 @@ class AR_NAR(Base):
         sampling_min_temperature: float = -1.0,
         sampling_top_k: int = -100,
         sampling_top_p: float = 1.0,
+        sampling_min_p: float = 0.0,
         sampling_repetition_penalty: float = 1.0,
         sampling_repetition_penalty_decay: float = 0.0,
         sampling_length_penalty: float = 0.0,
@@ -235,6 +236,7 @@ class AR_NAR(Base):
         min_temperature=sampling_min_temperature,
         top_p=sampling_top_p,
         top_k=sampling_top_k,
+        min_p=sampling_min_p,
         #repetition_penalty=sampling_repetition_penalty,
         #repetition_penalty_decay=sampling_repetition_penalty_decay,
         #length_penalty=sampling_length_penalty,
@@ -314,6 +316,7 @@ class AR_NAR(Base):
         min_temperature=sampling_min_temperature,
         top_p=sampling_top_p,
         top_k=sampling_top_k,
+        min_p=sampling_min_p,
         repetition_penalty=sampling_repetition_penalty,
         repetition_penalty_decay=sampling_repetition_penalty_decay,
         length_penalty=sampling_length_penalty,

@@ -45,6 +45,9 @@ Sampled = namedtuple('Sampled', ['out', 'scores', 'entropy']) # these seem more
 from ..utils.pattern import DelayedPatternProvider, VALLEPattern
 """
 
+def clamp(n, lo, hi):
+    return max(lo, min(n, hi))
+
 def _create_mask(l, device):
     """1 is valid region and 0 is invalid."""
     seq = torch.arange(max(l), device=device).unsqueeze(0) # (1 t)
@@ -1473,6 +1476,7 @@ class Base(nn.Module):
         min_temperature: float = -1.0, # activates dynamic temperature sampling
         top_k: int = -100,
         top_p: float = 1.0,
+        min_p: float = 0.0,
         # repetition penalty parameters
         repetition_penalty: float = 1.0,
         repetition_penalty_decay: float = 0.0,
@@ -1508,6 +1512,9 @@ class Base(nn.Module):
         if attentions is not None:
             entropy = [ calculate_entropix_metrics( logit, attn ) for logit, attn in zip(logits, attentions) ]
 
+        if attentions is not None:
+            entropix_enabled = True
+
         # this might actually slow things down a bit slightly-er?
         logits = [ logit.to(device="cpu", dtype=logit.dtype if logit.dtype != torch.float16 else torch.float32) for logit in logits ]
 
@@ -1523,6 +1530,7 @@ class Base(nn.Module):
             # adjust sample settings
             cfg = EntropixSamplerConfig()
 
+            entropy[0]["action"] = -1
             # Low Entropy, Low Varentropy: "flowing with unspoken intent"
             if ent < cfg.low_ent_thresh and vent < cfg.low_vent_thresh:
                 entropy[0]["action"] = 0
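(The action branches above pick a sampling strategy from entropy/varentropy quadrants. A rough re-derivation of the two logit-side metrics, for reference; the repo's calculate_entropix_metrics also folds in attention statistics not shown here:)

    import torch

    def logit_entropy_varentropy(logits: torch.Tensor):
        # entropy of the next-token distribution, plus the variance of
        # per-token surprisal around that entropy ("varentropy")
        log_probs = torch.log_softmax(logits, dim=-1)
        probs = log_probs.exp()
        entropy = -(probs * log_probs).sum(dim=-1)
        varentropy = (probs * (-log_probs - entropy.unsqueeze(-1)) ** 2).sum(dim=-1)
        return entropy, varentropy

    ent, vent = logit_entropy_varentropy(torch.randn(1, 1024))
    # low ent & low vent -> confident, near-greedy (action 0);
    # the other quadrants raise temperature, resample, or branch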
@@ -1551,13 +1559,14 @@ class Base(nn.Module):
         attn_uncertainty = attn_ent + attn_vent
 
         temperature = temperature * float(1 + cfg.ada_temp_logits * logits_uncertainty + cfg.ada_temp_attn * attn_uncertainty - cfg.ada_temp_agree * agreement)
-        top_p = torch.clip(top_p * (1 + cfg.ada_top_p * attn_vent), min=0.1, max=1.0).item()
+        top_p = float(torch.clip(top_p * (1 + cfg.ada_top_p * attn_vent), min=0.1, max=1.0))
         top_k = int(torch.clip(
             torch.round(top_k * (1 + cfg.ada_top_k_int * interaction_strength - cfg.ada_top_k_agree * agreement)),
             min=cfg.top_k_min,
             max=cfg.top_k_max
         ))
-        min_p = torch.clip(cfg.min_p * (1 - cfg.ada_min_p * logits_uncertainty), 0.01, 0.5)
+        min_p = float(torch.clip(cfg.min_p * (1 - cfg.ada_min_p * logits_uncertainty), 0.01, 0.5))
+        temperature = clamp( temperature, cfg.temperature_min, cfg.temperature_max )
 
         def _sample( logits ):
             # perform repetition penalizing
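(The adjustment above rescales every sampler knob from the uncertainty metrics and then clamps. A standalone trace of that arithmetic with made-up metric values; the coefficients are this commit's EntropixSamplerConfig defaults:)

    # hypothetical metric values; the real ones come from calculate_entropix_metrics
    logits_uncertainty = 2.5   # ent + vent
    attn_uncertainty = 1.8     # attn_ent + attn_vent
    agreement, attn_vent, interaction_strength = 0.4, 0.9, 1.1

    temperature = 0.85 * (1 + 0.3 * logits_uncertainty + 0.2 * attn_uncertainty - 0.2 * agreement)
    top_p = min(max(0.90 * (1 + 0.1 * attn_vent), 0.1), 1.0)            # -> 0.981
    top_k = int(min(max(round(27 * (1 + 0.3 * interaction_strength - 0.2 * agreement)), 1), 1024))  # -> 34
    min_p = min(max(0.01 * (1 - 0.5 * logits_uncertainty), 0.01), 0.5)  # floors at 0.01
    temperature = min(max(temperature, 0.5), 1.25)                      # clamped to 1.25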
@@ -1569,6 +1578,9 @@ class Base(nn.Module):
         if quant_levels is None and self.causal and prev_list is not None and length_penalty != 0.0:
             logits = [ length_penalize(logit, length=l + 1, factor=length_penalty, token=self.stop_token) for logit, l in zip( logits, map(len, prev_list) ) ]
 
+        if min_p > 0.0:
+            logits = [ min_p_filtering(logit, min_p=min_p) for logit in logits ]
+
         # perform top_k/top_p filtering of our logits
         if top_k > 0 or top_p < 1.0:
             logits = [ top_k_top_p_filtering(logit, top_k=top_k, top_p=top_p) for logit in logits ]
@@ -1586,30 +1598,44 @@ class Base(nn.Module):
 
             return [ Categorical(logits=logit).sample() for logit in logits ]
 
-        samples = [ _sample([ logit.clone() for logit in logits ]) for _ in range(cfg.n_adaptive_samples) ]
+        if entropix_enabled:
+            samples = [ _sample([ logit.clone() for logit in logits ]) for _ in range(cfg.n_adaptive_samples) ]
 
-        def score_sample(sample):
-            one_hot = torch.nn.functional.one_hot(sample[0], logit.shape[-1])
-            log_prob = torch.sum(log_softmax * one_hot)
+            def score_sample(sample):
+                one_hot = torch.nn.functional.one_hot(sample[0], logit.shape[-1])
+                log_prob = torch.sum(log_softmax * one_hot)
 
-            confidence_score = (
-                (1 - ent) * cfg.ada_score_logits_ent +
-                (1 - attn_ent) * cfg.ada_score_attn_ent +
-                (1 - vent) * cfg.ada_score_logits_vent +
-                (1 - attn_vent) * cfg.ada_score_attn_vent +
-                agreement * cfg.ada_score_agree +
-                interaction_strength * cfg.ada_score_int
-            )
-            return log_prob + confidence_score
+                confidence_score = (
+                    (1 - ent) * cfg.ada_score_logits_ent +
+                    (1 - attn_ent) * cfg.ada_score_attn_ent +
+                    (1 - vent) * cfg.ada_score_logits_vent +
+                    (1 - attn_vent) * cfg.ada_score_attn_vent +
+                    agreement * cfg.ada_score_agree +
+                    interaction_strength * cfg.ada_score_int
+                )
+                return log_prob + confidence_score
 
-        sample_scores = [ score_sample(sample) for sample in samples ]
-        best_sample_idx = torch.argmax(torch.asarray(sample_scores))
-
-        res = samples[best_sample_idx]
-        scores = sample_scores
-        return Sampled(res, scores, entropy)
+            sample_scores = [ score_sample(sample) for sample in samples ]
+            best_sample_idx = torch.argmax(torch.asarray(sample_scores))
+
+            res = samples[best_sample_idx]
+            scores = sample_scores
+            return Sampled(res, scores, entropy)
 
-        temperature = min(1.5, float(temperature))
+        temperature = clamp( float(temperature), cfg.temperature_min, cfg.temperature_max )
         min_temperature = temperature
+
+        entropy[0]["temperature"] = temperature
+        entropy[0]["top_k"] = top_k
+        entropy[0]["top_p"] = top_p
+        entropy[0]["min_p"] = min_p
+
+        if not entropix_enabled:
+            temperature = 1.0
+            min_temperature = 1.0
+            top_k = 0
+            top_p = 1.0
+            min_p = 0.0
+
         # (NAR) disable stop token
         if quant_levels is not None and "ar" in self.capabilities:
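(The gated best-of-n selection above can be exercised in isolation. A toy version, with the six-term confidence score collapsed to a single scalar bonus for brevity:)

    import torch

    def pick_best(logits: torch.Tensor, n_samples: int = 5, confidence: float = 0.0) -> torch.Tensor:
        # draw n candidate tokens, score each by its log-probability plus a
        # confidence bonus, and keep the argmax
        log_softmax = torch.log_softmax(logits, dim=-1)
        dist = torch.distributions.Categorical(logits=logits)
        samples = [ dist.sample() for _ in range(n_samples) ]
        scores = [ log_softmax.gather(-1, s.unsqueeze(-1)).sum() + confidence for s in samples ]
        return samples[ int(torch.argmax(torch.stack(scores))) ]

    token = pick_best(torch.randn(1, 1024))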
@@ -1633,6 +1659,10 @@ class Base(nn.Module):
         if quant_levels is None and self.causal and prev_list is not None and length_penalty != 0.0:
             logits = [ length_penalize(logit, length=l + 1, factor=length_penalty, token=self.stop_token) for logit, l in zip( logits, map(len, prev_list) ) ]
 
+        # perform min_p filtering of our logits
+        if min_p > 0.0:
+            logits = [ min_p_filtering(logit, min_p=min_p) for logit in logits ]
+
         # perform top_k/top_p filtering of our logits
         if top_k > 0 or top_p < 1.0:
             logits = [ top_k_top_p_filtering(logit, top_k=top_k, top_p=top_p) for logit in logits ]

@@ -45,6 +45,7 @@ class NAR(Base):
         sampling_min_temperature: float = -1.0,
         sampling_top_k: int = -100,
         sampling_top_p: float = 1.0,
+        sampling_min_p: float = 0.0,
         sampling_repetition_penalty: float = 1.0,
         sampling_repetition_penalty_decay: float = 0.0,
         sampling_length_penalty: float = 0.0,
@@ -191,6 +192,7 @@ class NAR(Base):
         min_temperature=sampling_min_temperature,
         top_p=sampling_top_p,
         top_k=sampling_top_k,
+        min_p=sampling_min_p,
         repetition_penalty=sampling_repetition_penalty,
         repetition_penalty_decay=sampling_repetition_penalty_decay,
         #length_penalty=sampling_length_penalty,

@@ -50,6 +50,26 @@ def ban_tokens( logits, tokens ):
         logits[:, token] = -float("inf")
     return logits
 
+# Performs min_p filtering
+# From https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/generation/logits_process.py#L537
+def min_p_filtering( logits, min_p=0.0, min_tokens_to_keep=32 ):
+    if min_p <= 0.0:
+        return logits
+
+    # Convert logits to probabilities
+    probs = torch.softmax(logits, dim=-1)
+    # Get the probability of the top token for each sequence in the batch
+    top_probs, _ = probs.max(dim=-1, keepdim=True)
+    # Calculate the actual min_p threshold by scaling min_p with the top token's probability
+    scaled_min_p = min_p * top_probs
+
+    sorted_indices = torch.argsort(logits, descending=True, dim=-1)
+    sorted_indices_to_remove = torch.gather(probs < scaled_min_p, dim=-1, index=sorted_indices)
+    sorted_indices_to_remove[..., :min_tokens_to_keep] = False
+
+    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
+    return logits.masked_fill(indices_to_remove, -float("inf"))
+
 # Credit to https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py#L1145 / https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
 def top_k_top_p_filtering( logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens=1 ):
     """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
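(A quick smoke test of the new helper, assuming it is imported from this module; the tensor values are arbitrary. Because min_tokens_to_keep defaults to 32, at least 32 candidates always survive:)

    import torch

    logits = torch.randn(1, 1024)
    filtered = min_p_filtering(logits.clone(), min_p=0.05)
    kept = (filtered > -float("inf")).sum().item()
    print(f"{kept} of 1024 candidates survive min_p=0.05")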
@@ -243,46 +263,48 @@ def calculate_entropix_metrics( logits, attention_scores=None, dim=-1 ):
 # to-do: play around with these values
 @dataclass()
 class EntropixSamplerConfig:
-    temp: float = 0.999
-    top_p: float = 0.90
-    top_k: int = 32
-    min_p: float = 0.01 # was 0.03 # Turn this down to 0.01 to reduce the shoggoth
+    temp: float = 0.85
+    top_p: float = 0.90
+    top_k: int = 27
+    min_p: float = 0.01 # was 0.03 # Turn this down to 0.01 to reduce the shoggoth
 
-    low_ent_thresh: float = 0.1
-    low_vent_thresh: float = 0.1
-    med_ent_thresh: float = 3.0
-    high_ent_thresh: float = 5.0
-    high_vent_thresh: float = 5.0
+    low_ent_thresh: float = 0.1 # 3.0
+    low_vent_thresh: float = 0.1 # 3.0
+    med_ent_thresh: float = 3.0 # 6.0
+    high_ent_thresh: float = 5.0 # 9.0
+    high_vent_thresh: float = 5.0 # 9.0
 
-    # TODO this is a bit of a nasty mess, but also makes all the hyperparameters visible
-    helv_attn_ent_offset: float = 1.3
-    helv_attn_ent_coef: float = 0.2
+    # TODO this is a bit of a nasty mess, but also makes all the hyperparameters visible
+    helv_attn_ent_offset: float = 1.3
+    helv_attn_ent_coef: float = 0.2
 
-    lehv_interaction_strength_offset: float = 1.2
-    lehv_interaction_strength_coef: float = 0.3
+    lehv_interaction_strength_offset: float = 1.2
+    lehv_interaction_strength_coef: float = 0.3
 
-    hehv_attn_ent_coef: float = 0.2
-    hehv_attn_vent_offset: float = 2.0
-    hehv_attn_vent_coef: float = 0.5
+    hehv_attn_ent_coef: float = 0.2
+    hehv_attn_vent_offset: float = 2.0
+    hehv_attn_vent_coef: float = 0.5
 
-    # TODO not convinced this should
-    n_adaptive_samples: int = 5
+    # TODO not convinced this should
+    n_adaptive_samples: int = 5
 
-    # Adaptive sampling parameters
-    ada_temp_logits: float = 0.3
-    ada_temp_attn: float = 0.2
-    ada_temp_agree: float = 0.2
-    ada_top_p: float = 0.1
-    ada_top_k_int: float = 0.3
-    ada_top_k_agree: float = 0.2
-    ada_min_p: float = 0.5
-    ada_score_logits_ent: float = 0.1
-    ada_score_attn_ent: float = 0.2
-    ada_score_logits_vent: float = 0.3
-    ada_score_attn_vent: float = 0.4
-    ada_score_agree: float = 0.5
-    ada_score_int: float = 0.6
+    # Adaptive sampling parameters
+    ada_temp_logits: float = 0.3
+    ada_temp_attn: float = 0.2
+    ada_temp_agree: float = 0.2
+    ada_top_p: float = 0.1
+    ada_top_k_int: float = 0.3
+    ada_top_k_agree: float = 0.2
+    ada_min_p: float = 0.5
+    ada_score_logits_ent: float = 0.1
+    ada_score_attn_ent: float = 0.2
+    ada_score_logits_vent: float = 0.3
+    ada_score_attn_vent: float = 0.4
+    ada_score_agree: float = 0.5
+    ada_score_int: float = 0.6
 
-    # extra stuff
-    top_k_min: int = 32
-    top_k_max: int = 128
+    # extra stuff
+    top_k_min: int = 1
+    top_k_max: int = 1024
+    temperature_max: float = 1.25
+    temperature_min: float = 0.5

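(Hypothetical usage showing how the retuned defaults sit inside the new clamping bounds:)

    cfg = EntropixSamplerConfig()
    assert cfg.top_k_min <= cfg.top_k <= cfg.top_k_max              # 1 <= 27 <= 1024
    assert cfg.temperature_min <= cfg.temp <= cfg.temperature_max   # 0.5 <= 0.85 <= 1.25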
@@ -157,6 +157,7 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
     parser.add_argument("--min-nar-temp", type=float, default=kwargs["min-nar-temp"])
     parser.add_argument("--top-p", type=float, default=kwargs["top-p"])
     parser.add_argument("--top-k", type=int, default=kwargs["top-k"])
+    parser.add_argument("--min-p", type=float, default=kwargs["min-p"])
     parser.add_argument("--repetition-penalty", type=float, default=kwargs["repetition-penalty"])
     parser.add_argument("--repetition-penalty-decay", type=float, default=kwargs["repetition-penalty-decay"])
     parser.add_argument("--length-penalty", type=float, default=kwargs["length-penalty"])
@@ -196,6 +197,7 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
         min_nar_temp=args.min_nar_temp,
         top_p=args.top_p,
         top_k=args.top_k,
+        min_p=args.min_p,
         repetition_penalty=args.repetition_penalty,
         repetition_penalty_decay=args.repetition_penalty_decay,
         length_penalty=args.length_penalty,
@@ -228,6 +230,7 @@ def do_inference_stt( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
     parser.add_argument("--min-ar-temp", type=float, default=kwargs["min-ar-temp"])
     parser.add_argument("--top-p", type=float, default=kwargs["top-p"])
     parser.add_argument("--top-k", type=int, default=kwargs["top-k"])
+    parser.add_argument("--min-p", type=float, default=kwargs["min-p"])
     parser.add_argument("--repetition-penalty", type=float, default=kwargs["repetition-penalty"])
     parser.add_argument("--repetition-penalty-decay", type=float, default=kwargs["repetition-penalty-decay"])
     parser.add_argument("--length-penalty", type=float, default=kwargs["length-penalty"])
@@ -266,6 +269,7 @@ def do_inference_stt( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
         min_ar_temp=args.min_ar_temp,
         top_p=args.top_p,
         top_k=args.top_k,
+        min_p=args.min_p,
         repetition_penalty=args.repetition_penalty,
         repetition_penalty_decay=args.repetition_penalty_decay,
         length_penalty=args.length_penalty,
@@ -343,6 +347,7 @@ with ui:
         with gr.Row():
             layout["inference_tts"]["inputs"]["top-p"] = gr.Slider(value=1.0, minimum=0.0, maximum=1.0, step=0.05, label="Top P", info=r"Limits the samples that are outside the top P% of probabilities.")
             layout["inference_tts"]["inputs"]["top-k"] = gr.Slider(value=0, minimum=0, maximum=1024, step=1, label="Top K", info="Limits the samples to the top K of probabilities.")
+            layout["inference_tts"]["inputs"]["min-p"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.0, step=0.05, label="Min P")
             layout["inference_tts"]["inputs"]["beam-width"] = gr.Slider(value=0, minimum=0, maximum=32, step=1, label="Beam Width", info="Number of branches to search through for beam search sampling.")
         with gr.Row():
             layout["inference_tts"]["inputs"]["repetition-penalty"] = gr.Slider(value=1.0, minimum=-2.0, maximum=2.0, step=0.05, label="Repetition Penalty", info="Incurs a penalty to tokens based on how often they appear in a sequence.")
@@ -382,6 +387,7 @@ with ui:
         with gr.Row():
             layout["inference_stt"]["inputs"]["top-p"] = gr.Slider(value=1.0, minimum=0.0, maximum=1.0, step=0.05, label="Top P", info=r"Limits the samples that are outside the top P% of probabilities.")
             layout["inference_stt"]["inputs"]["top-k"] = gr.Slider(value=0, minimum=0, maximum=1024, step=1, label="Top K", info="Limits the samples to the top K of probabilities.")
+            layout["inference_stt"]["inputs"]["min-p"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.0, step=0.05, label="Min P")
             layout["inference_stt"]["inputs"]["beam-width"] = gr.Slider(value=0, minimum=0, maximum=32, step=1, label="Beam Width", info="Number of branches to search through for beam search sampling.")
         with gr.Row():
             layout["inference_stt"]["inputs"]["repetition-penalty"] = gr.Slider(value=1.25, minimum=-2.0, maximum=2.0, step=0.05, label="Repetition Penalty", info="Incurs a penalty to tokens based on how often they appear in a sequence.")