new meme sampler PogChamp (it sort of helps?)

mrq 2024-11-12 22:30:09 -06:00
parent 663f07038d
commit 0f2584eba7
9 changed files with 112 additions and 52 deletions

View File

@@ -35,6 +35,7 @@ def main():
parser.add_argument("--top-p", type=float, default=1.0)
parser.add_argument("--top-k", type=int, default=0)
+ parser.add_argument("--top-no", type=float, default=0.0)
parser.add_argument("--min-p", type=float, default=0.0)
parser.add_argument("--repetition-penalty", type=float, default=1.0)
parser.add_argument("--repetition-penalty-decay", type=float, default=0.0)
@@ -83,7 +84,7 @@ def main():
max_duration=args.max_duration,
ar_temperature=args.ar_temperature, nar_temperature=args.nar_temperature,
min_ar_temperature=args.min_ar_temperature, min_nar_temperature=args.min_nar_temperature,
- top_p=args.top_p, top_k=args.top_k, min_p=args.min_p,
+ top_p=args.top_p, top_k=args.top_k, top_no=args.top_no, min_p=args.min_p,
repetition_penalty=args.repetition_penalty, repetition_penalty_decay=args.repetition_penalty_decay,
length_penalty=args.length_penalty,
beam_width=args.beam_width,

View File

@@ -1383,6 +1383,9 @@ class Dataset(_Dataset):
def training_(self, value):
self.training = value
+ def index(self):
+ return self.sampler.index() if self.sampler is not None else -1
def __len__(self):
if self.sampler_type == "group":
return min(len(self.spkr_groups), self._head or len(self.spkr_groups))

View File

@@ -57,19 +57,32 @@ def main():
parser.add_argument("--language", type=str, default="en")
- parser.add_argument("--max-ar-steps", type=int, default=12 * cfg.dataset.frames_per_second)
- parser.add_argument("--max-nar-levels", type=int, default=7)
- parser.add_argument("--ar-temp", type=float, default=0.0)
- parser.add_argument("--nar-temp", type=float, default=0.0)
- parser.add_argument("--min-ar-temp", type=float, default=-1.0)
- parser.add_argument("--min-nar-temp", type=float, default=-1.0)
- parser.add_argument("--input-prompt-length", type=float, default=0.0)
+ parser.add_argument("--task", type=str, default="tts")
+ parser.add_argument("--out-path", type=Path, default=None)
+ parser.add_argument("--yaml", type=Path, default=None)
+ parser.add_argument("--model", type=Path, default=None)
+ parser.add_argument("--lora", type=Path, default=None)
+ parser.add_argument("--max-duration", type=int, default=12 * cfg.dataset.frames_per_second)
+ parser.add_argument("--max-steps", type=int, default=25)
+ parser.add_argument("--max-levels", type=int, default=7)
+ parser.add_argument("--ar-temperature", type=float, default=1.0)
+ parser.add_argument("--nar-temperature", type=float, default=0.0)
+ parser.add_argument("--min-ar-temperature", type=float, default=-1.0)
+ parser.add_argument("--min-nar-temperature", type=float, default=-1.0)
+ parser.add_argument("--input-prompt-length", type=float, default=3.0)
+ parser.add_argument("--input-prompt-prefix", action="store_true")
+ parser.add_argument("--prefix-silence", type=float, default=0.0)
+ parser.add_argument("--cfg-strength", type=float, default=0.0)
parser.add_argument("--top-p", type=float, default=1.0)
parser.add_argument("--top-k", type=int, default=0)
+ parser.add_argument("--top-no", type=float, default=0.0)
parser.add_argument("--min-p", type=float, default=0.0)
- parser.add_argument("--repetition-penalty", type=float, default=1.125)
+ parser.add_argument("--repetition-penalty", type=float, default=1.0)
parser.add_argument("--repetition-penalty-decay", type=float, default=0.0)
parser.add_argument("--length-penalty", type=float, default=0.0)
parser.add_argument("--beam-width", type=int, default=0)
@@ -89,6 +102,10 @@ def main():
parser.add_argument("--layer-skip-varentropy-threshold", type=int, default=0.1)
parser.add_argument("--refine-on-stop", action="store_true")
+ # experimental settings
+ parser.add_argument("--load-from-artifact", type=Path, default=None)
+ parser.add_argument("--denoise-start", type=float, default=0.0)
parser.add_argument("--seed", type=int, default=None)
parser.add_argument("--device", type=str, default=None)
@@ -135,19 +152,19 @@ def main():
comparison_kwargs["titles"] = ["LoRA", "No LoRA"]
comparison_kwargs["disabled"]["use_lora"] = True
- comparison_kwargs["disabled"]["ar_temp"] = 0.0
+ comparison_kwargs["disabled"]["ar_temperature"] = 0.0
comparison_kwargs["enabled"]["use_lora"] = False
- comparison_kwargs["enabled"]["ar_temp"] = 0.95
+ comparison_kwargs["enabled"]["ar_temperature"] = 0.95
elif args.comparison == "entropix-sampling":
comparison_kwargs["suffix"] = "entropix_sampling"
comparison_kwargs["titles"] = ["Without Entropix", "With Entropix"]
comparison_kwargs["disabled"]["entropix_sampling"] = False
- comparison_kwargs["disabled"]["ar_temp"] = args.ar_temp
+ comparison_kwargs["disabled"]["ar_temperature"] = args.ar_temperature
comparison_kwargs["disabled"]["top_k"] = args.top_k
comparison_kwargs["disabled"]["top_p"] = args.top_p
comparison_kwargs["enabled"]["entropix_sampling"] = True
- comparison_kwargs["enabled"]["ar_temp"] = 0.666
+ comparison_kwargs["enabled"]["ar_temperature"] = 0.666
comparison_kwargs["enabled"]["top_k"] = 27
comparison_kwargs["enabled"]["top_p"] = 0.9
elif args.comparison == "layerskip":
@@ -163,14 +180,14 @@ def main():
comparison_kwargs["disabled"]["refine_on_stop"] = False
comparison_kwargs["enabled"]["refine_on_stop"] = True
elif args.comparison == "ar-temp":
- current_temp = args.ar_temp
- other_temp = 1.0
+ current_temperature = args.ar_temperature
+ other_temperature = 1.0
comparison_kwargs["suffix"] = "temperature"
- comparison_kwargs["titles"] = [f"Temp: {current_temp:.2f}", f"Temp: {other_temp:.2f}"]
- comparison_kwargs["disabled"]["ar_temp"] = current_temp
- comparison_kwargs["enabled"]["ar_temp"] = other_temp
+ comparison_kwargs["titles"] = [f"Temp: {current_temperature:.2f}", f"Temp: {other_temperature:.2f}"]
+ comparison_kwargs["disabled"]["ar_temperature"] = current_temperature
+ comparison_kwargs["enabled"]["ar_temperature"] = other_temperature
elif args.comparison == "input-prompt-length":
current_length = args.input_prompt_length
other_length = 3.0
@@ -209,21 +226,34 @@ def main():
# read html template
html = open(args.demo_dir / "index.template.html", "r", encoding="utf-8").read()
- # replace values in our template
- html = html.replace(r"${PREAMBLE}", args.preamble )
- html = html.replace(r"${SETTINGS}", str(dict(
- input_prompt_length=args.input_prompt_length,
- max_ar_steps=args.max_ar_steps, max_nar_levels=args.max_nar_levels,
- ar_temp=args.ar_temp, nar_temp=args.nar_temp,
- min_ar_temp=args.min_ar_temp, min_nar_temp=args.min_nar_temp,
- top_p=args.top_p, top_k=args.top_k, min_p=args.min_p,
+ sampling_kwargs = dict(
+ max_steps=args.max_steps,
+ max_levels=args.max_levels,
+ max_duration=args.max_duration,
+ ar_temperature=args.ar_temperature, nar_temperature=args.nar_temperature,
+ min_ar_temperature=args.min_ar_temperature, min_nar_temperature=args.min_nar_temperature,
+ top_p=args.top_p, top_k=args.top_k, top_no=args.top_no, min_p=args.min_p,
repetition_penalty=args.repetition_penalty, repetition_penalty_decay=args.repetition_penalty_decay,
length_penalty=args.length_penalty,
beam_width=args.beam_width,
mirostat_tau=args.mirostat_tau, mirostat_eta=args.mirostat_eta,
dry_multiplier=args.dry_multiplier, dry_base=args.dry_base, dry_allowed_length=args.dry_allowed_length,
entropix_sampling=args.entropix_sampling,
- )) )
+ layer_skip=args.layer_skip,
+ layer_skip_exit_layer=args.layer_skip_exit_layer,
+ layer_skip_entropy_threshold=args.layer_skip_entropy_threshold,
+ layer_skip_varentropy_threshold=args.layer_skip_varentropy_threshold,
+ refine_on_stop=args.refine_on_stop,
+ denoise_start=args.denoise_start,
+ input_prompt_length=args.input_prompt_length,
+ input_prompt_prefix=args.input_prompt_prefix,
+ prefix_silence=args.prefix_silence,
+ cfg_strength=args.cfg_strength,
+ )
+ # replace values in our template
+ html = html.replace(r"${PREAMBLE}", args.preamble )
+ html = html.replace(r"${SETTINGS}", str(sampling_kwargs))
# pull from provided samples
samples_dirs = {
@@ -324,18 +354,9 @@ def main():
references=[prompt],
language=language,
input_prompt_length=args.input_prompt_length,
- max_ar_steps=args.max_ar_steps, max_nar_levels=args.max_nar_levels,
- ar_temp=args.ar_temp, nar_temp=args.nar_temp,
- min_ar_temp=args.min_ar_temp, min_nar_temp=args.min_nar_temp,
- top_p=args.top_p, top_k=args.top_k,
- repetition_penalty=args.repetition_penalty, repetition_penalty_decay=args.repetition_penalty_decay,
- length_penalty=args.length_penalty,
- beam_width=args.beam_width,
- mirostat_tau=args.mirostat_tau, mirostat_eta=args.mirostat_eta,
- dry_multiplier=args.dry_multiplier, dry_base=args.dry_base, dry_allowed_length=args.dry_allowed_length,
- entropix_sampling=args.entropix_sampling,
seed=seed,
tqdm=False,
+ **sampling_kwargs,
)
def safe_inference( out_path=out_path ):

View File

@@ -31,6 +31,8 @@ from .lora import enable_lora
text_task = [ "stt" ]
class AR_NAR(Base):
+ # parse inputs for training
+ # a lot of this could be delegated back to the dataloader, but it's just easier to keep the task of the dataloader to provide sufficient data, and the model to process the data for training
def forward_train(
self,
text_list: list[Tensor],
@@ -62,19 +64,17 @@ class AR_NAR(Base):
token_dropout_rvq_levels = self.config.experimental.token_dropout_rvq_levels
# RVQ levels to apply masking training on
masking_train_rvq_levels = self.config.experimental.masking_train_rvq_levels
- # force set mask training
- if "len" not in self.capabilities:
- masking_train_rvq_levels = 0.0
- elif "ar" not in self.capabilities:
- masking_train_rvq_levels = 1.0
# CFG
cfg_text_dropout_p = self.config.experimental.cfg_text_dropout_p if self.config is not None else 0.0
cfg_cond_dropout_p = self.config.experimental.cfg_cond_dropout_p if self.config is not None else 0.0
cfg_prom_dropout_p = self.config.experimental.cfg_prom_dropout_p if self.config is not None else 0.0
# rate to train RVQ level AR-ly or NAR-ly
masking_train_p = self.config.experimental.masking_train_p if self.config is not None else 0.5
+ # force set mask training
+ if "len" not in self.capabilities:
+ masking_train_p = 0.0
+ elif "ar" not in self.capabilities:
+ masking_train_p = 1.0
# implicitly set it to all levels
if not token_dropout_rvq_levels:
token_dropout_rvq_levels = [0, self.resp_levels - 1]
@@ -116,9 +116,11 @@ class AR_NAR(Base):
text_stop_sequence = torch.tensor([2], device=device, dtype=torch.int16)
text_start_stop_sequence = torch.tensor([1, 2], device=device, dtype=torch.int16)
audio_stop_sequence = torch.tensor([[self.stop_token]], device=device, dtype=torch.int16)
+ # I hate python's value/reference semantics so much
+ # final validations and stuff
for i, quant_level, resps, proms, task in zip(range(batch_size), quant_levels, resps_list, proms_list, task_list):
# cap quant_level if it exceeds its corresponding resp/prom
+ # this was needed for when my DAC-encoded audio was erroneously trimmed to 8 RVQ levels instead of 9
if quant_level >= resps.shape[-1]:
quant_levels[i] = resps.shape[-1] - 1

View File

@@ -1706,7 +1706,9 @@ class Base(nn.Module):
dry_multiplier = sampling_kwargs.get("dry_multiplier", 0.0)
dry_base = sampling_kwargs.get("dry_base", 1.75)
dry_allowed_length = sampling_kwargs.get("dry_allowed_length", 2)
+ #
+ top_no = sampling_kwargs.get("top_no", 1.0)
+ #
attentions = sampling_kwargs.get("attentions", None)
batch_size = len( logits )
@@ -1792,6 +1794,10 @@ class Base(nn.Module):
elif temperature > 0.0:
logits = [ logit / temperature for logit in logits ]
+ # do top-no logit processing
+ if top_no > 0.0:
+ logits = [ top_no_logits_processing(logit) for logit in logits ]
# do DRY sampling
if dry_multiplier > 0.0 and prev_list is not None:
logits = [ dry_sampling(logit, previous=prevs, factor=dry_multiplier, base=dry_base, allowed_length=dry_allowed_length) for logit, prevs in zip( logits, prev_list ) ]

View File

@@ -159,6 +159,19 @@ def top_k_logits_list( logits_list, k ):
candidates[i] = tuple(t)
return candidates
+ # top-nσ logit processing
+ # from https://arxiv.org/abs/2411.07641
+ def top_no_logits_processing( logits, n = 1.0 ):
+ M = torch.max(logits, dim=-1, keepdim=True).values
+ σ = torch.std(logits, dim=-1, keepdim=True)
+ mask = logits >= M - n * σ
+ n_inf = torch.full_like( logits, -float("inf") )
+ logits = torch.where( mask, logits, n_inf )
+ return logits
# Credit to: https://github.com/basusourya/mirostat/
# performs mirostat-based sampling
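
For reference, a minimal self-contained sketch of what the top-nσ filter above does (toy inputs and renamed identifiers, not the repo's call path): it keeps only the logits that fall within n standard deviations of the maximum logit and pushes everything else to -inf, so only the surviving tokens carry probability after softmax. Unlike top-k or top-p, the number of surviving candidates adapts to how peaked the distribution is.

import torch

def top_nsigma(logits: torch.Tensor, n: float = 1.0) -> torch.Tensor:
	# keep tokens whose logit is within n standard deviations of the max logit
	max_logit = logits.max(dim=-1, keepdim=True).values
	sigma = logits.std(dim=-1, keepdim=True)
	mask = logits >= (max_logit - n * sigma)
	return logits.masked_fill(~mask, float("-inf"))

logits = torch.tensor([[5.0, 4.8, 1.0, -2.0, 0.5]])
probs = torch.softmax(top_nsigma(logits, n=1.0), dim=-1)
# only the two tokens near the maximum survive; the rest get probability 0
print(probs)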

View File

@@ -44,6 +44,9 @@ class PoolSampler():
def __call__(self, *args, **kwargs):
return self.sample(*args, **kwargs)
+ def index(self):
+ return len(self.global_indices) - len(self.current_pool)
def get_state(self):
return { "length": self.length, "global_pool": self.global_pool, "global_indices": self.global_indices, "current_pool": self.current_pool }
@@ -72,6 +75,9 @@ class OrderedSampler(Sampler):
yield self.position
self.position += 1
+ def index(self):
+ return self.position
def get_state(self):
return { "position": self.position, "length": self.length }
@@ -125,6 +131,9 @@ class BatchedOrderedSampler(Sampler):
yield self.batches[self.position]
self.position += 1
+ def index(self):
+ return self.position
def get_state(self):
return { "position": self.position, "batches": self.batches }
@@ -154,6 +163,9 @@ class RandomSampler(Sampler):
yield self.perm[self.position]
self.position += 1
+ def index(self):
+ return self.position
def get_state(self):
return { "position": self.position, "length": self.length, "perm": self.perm, "generator": self.generator.get_state() }
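
These index() accessors report how far each sampler has advanced through its epoch; Dataset.index() earlier simply forwards to them, and the commented-out tqdm(initial=..., total=...) hookup in the trainer change below is the intended consumer. A rough standalone illustration of the idea (hypothetical toy class, not one of the repo's samplers):

from tqdm import tqdm

class CountingSampler:
	# toy stand-in: yields indices in order and remembers how many it has handed out
	def __init__(self, length):
		self.length = length
		self.position = 0

	def __iter__(self):
		while self.position < self.length:
			value = self.position
			self.position += 1
			yield value

	def __len__(self):
		return self.length

	def index(self):
		# epoch progress so far, usable as tqdm's initial= when resuming
		return self.position

sampler = CountingSampler(100)
it = iter(sampler)
for _ in range(30):
	next(it)

# resume a progress bar mid-epoch from the sampler's reported position
for _ in tqdm(it, initial=sampler.index(), total=len(sampler)):
	pass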

View File

@@ -104,7 +104,7 @@ def _non_blocking_input():
def _make_infinite_epochs(dl):
while True:
#_logger.info("New epoch starts.")
- yield from tqdm(dl, "Epoch progress", dynamic_ncols=True, disable=not is_global_leader())
+ yield from tqdm(dl, "Epoch progress", dynamic_ncols=True, disable=not is_global_leader()) # , initial=dl.dataset.index(), total=len(dl.dataset)) # to-do: figure out why this number jumps
@local_leader_only(default=None) @local_leader_only(default=None)

View File

@@ -216,6 +216,7 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
parser.add_argument("--prefix-silence", type=float, default=kwargs["prefix-silence"])
parser.add_argument("--top-p", type=float, default=kwargs["top-p"])
parser.add_argument("--top-k", type=int, default=kwargs["top-k"])
+ parser.add_argument("--top-no", type=float, default=kwargs["top-no"])
parser.add_argument("--min-p", type=float, default=kwargs["min-p"])
parser.add_argument("--repetition-penalty", type=float, default=kwargs["repetition-penalty"])
parser.add_argument("--repetition-penalty-decay", type=float, default=kwargs["repetition-penalty-decay"])
@@ -265,7 +266,7 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
max_duration=args.max_duration,
ar_temperature=args.ar_temperature, nar_temperature=args.nar_temperature,
min_ar_temperature=args.min_ar_temperature, min_nar_temperature=args.min_nar_temperature,
- top_p=args.top_p, top_k=args.top_k, min_p=args.min_p,
+ top_p=args.top_p, top_k=args.top_k, min_p=args.min_p, top_no=args.top_no,
repetition_penalty=args.repetition_penalty, repetition_penalty_decay=args.repetition_penalty_decay,
length_penalty=args.length_penalty,
beam_width=args.beam_width,
@@ -434,8 +435,8 @@ with ui:
with gr.Row():
layout["inference_tts"]["inputs"]["top-p"] = gr.Slider(value=1.0, minimum=0.0, maximum=1.0, step=0.05, label="Top P", info=r"Limits the samples that are outside the top P% of probabilities.")
layout["inference_tts"]["inputs"]["top-k"] = gr.Slider(value=0, minimum=0, maximum=1024, step=1, label="Top K", info="Limits the samples to the top K of probabilities.")
+ layout["inference_tts"]["inputs"]["top-no"] = gr.Slider(value=1, minimum=0, maximum=2, step=0.05, label="Top-nσ", info="Performs top-nσ logits processing.")
layout["inference_tts"]["inputs"]["min-p"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.0, step=0.05, label="Min P")
- layout["inference_tts"]["inputs"]["beam-width"] = gr.Slider(value=0, minimum=0, maximum=32, step=1, label="Beam Width", info="Number of branches to search through for beam search sampling.")
with gr.Row():
layout["inference_tts"]["inputs"]["repetition-penalty"] = gr.Slider(value=1.0, minimum=0.0, maximum=5.0, step=0.05, label="Repetition Penalty", info="Incurs a penalty to tokens based on how often they appear in a sequence.")
layout["inference_tts"]["inputs"]["repetition-penalty-decay"] = gr.Slider(value=0.0, minimum=-2.0, maximum=2.0, step=0.05, label="Repetition Penalty Length Decay", info="Modifies the reptition penalty based on how far back in time the token appeared in the sequence.")
@@ -451,10 +452,11 @@ with ui:
with gr.Row():
layout["inference_tts"]["inputs"]["max-steps"] = gr.Slider(value=25, minimum=1, maximum=500, step=1, label="Max NAR Steps", info="Limits how many steps to perform in the NAR (demask) pass.")
layout["inference_tts"]["inputs"]["max-levels"] = gr.Slider(value=7, minimum=0, maximum=7, step=1, label="Max NAR Levels", info="Limits how many steps to perform in the NAR pass.")
- layout["inference_tts"]["inputs"]["input-prompt-prefix"] = gr.Checkbox(label="Input Prompt as Prefix", info="Treats the input prompt clip as the prefix of the generated sequence.")
with gr.Row():
+ layout["inference_tts"]["inputs"]["input-prompt-prefix"] = gr.Checkbox(label="Input Prompt as Prefix", info="Treats the input prompt clip as the prefix of the generated sequence.")
layout["inference_tts"]["inputs"]["prefix-silence"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.0, step=0.05, label="Silence Prefix Duration", info="Amount of silence to prefix to the output response before beginning inference.")
with gr.Row():
+ layout["inference_tts"]["inputs"]["beam-width"] = gr.Slider(value=0, minimum=0, maximum=32, step=1, label="Beam Width", info="Number of branches to search through for beam search sampling.")
layout["inference_tts"]["inputs"]["dynamic-sampling"] = gr.Checkbox(label="Dynamic Temperature", info="Dynamically adjusts the temperature based on the highest confident predicted token per sampling step.")
layout["inference_tts"]["inputs"]["entropix-sampling"] = gr.Checkbox(label="Entropix Sampling", info="Dynamically samples based on entropy/varentropy values from the logits / attention scores.")
with gr.Row():