new meme sampler PogChamp new meme sampler PogChamp (it sort of helps?)
This commit is contained in:
parent
663f07038d
commit
0f2584eba7
|
@ -35,6 +35,7 @@ def main():
|
|||
|
||||
parser.add_argument("--top-p", type=float, default=1.0)
|
||||
parser.add_argument("--top-k", type=int, default=0)
|
||||
parser.add_argument("--top-no", type=float, default=0.0)
|
||||
parser.add_argument("--min-p", type=float, default=0.0)
|
||||
parser.add_argument("--repetition-penalty", type=float, default=1.0)
|
||||
parser.add_argument("--repetition-penalty-decay", type=float, default=0.0)
|
||||
|
@ -83,7 +84,7 @@ def main():
|
|||
max_duration=args.max_duration,
|
||||
ar_temperature=args.ar_temperature, nar_temperature=args.nar_temperature,
|
||||
min_ar_temperature=args.min_ar_temperature, min_nar_temperature=args.min_nar_temperature,
|
||||
top_p=args.top_p, top_k=args.top_k, min_p=args.min_p,
|
||||
top_p=args.top_p, top_k=args.top_k, top_no=args.top_no,min_p=args.min_p,
|
||||
repetition_penalty=args.repetition_penalty, repetition_penalty_decay=args.repetition_penalty_decay,
|
||||
length_penalty=args.length_penalty,
|
||||
beam_width=args.beam_width,
|
||||
|
|
|
@ -1383,6 +1383,9 @@ class Dataset(_Dataset):
|
|||
def training_(self, value):
|
||||
self.training = value
|
||||
|
||||
def index(self):
|
||||
return self.sampler.index() if self.sampler is not None else -1
|
||||
|
||||
def __len__(self):
|
||||
if self.sampler_type == "group":
|
||||
return min(len(self.spkr_groups), self._head or len(self.spkr_groups))
|
||||
|
|
|
@ -57,37 +57,54 @@ def main():
|
|||
|
||||
parser.add_argument("--language", type=str, default="en")
|
||||
|
||||
parser.add_argument("--max-ar-steps", type=int, default=12 * cfg.dataset.frames_per_second)
|
||||
parser.add_argument("--max-nar-levels", type=int, default=7)
|
||||
parser.add_argument("--language", type=str, default="en")
|
||||
parser.add_argument("--task", type=str, default="tts")
|
||||
parser.add_argument("--out-path", type=Path, default=None)
|
||||
|
||||
parser.add_argument("--ar-temp", type=float, default=0.0)
|
||||
parser.add_argument("--nar-temp", type=float, default=0.0)
|
||||
parser.add_argument("--min-ar-temp", type=float, default=-1.0)
|
||||
parser.add_argument("--min-nar-temp", type=float, default=-1.0)
|
||||
parser.add_argument("--input-prompt-length", type=float, default=0.0)
|
||||
parser.add_argument("--yaml", type=Path, default=None)
|
||||
parser.add_argument("--model", type=Path, default=None)
|
||||
parser.add_argument("--lora", type=Path, default=None)
|
||||
|
||||
parser.add_argument("--max-duration", type=int, default=12 * cfg.dataset.frames_per_second)
|
||||
parser.add_argument("--max-steps", type=int, default=25)
|
||||
parser.add_argument("--max-levels", type=int, default=7)
|
||||
|
||||
parser.add_argument("--ar-temperature", type=float, default=1.0)
|
||||
parser.add_argument("--nar-temperature", type=float, default=0.0)
|
||||
parser.add_argument("--min-ar-temperature", type=float, default=-1.0)
|
||||
parser.add_argument("--min-nar-temperature", type=float, default=-1.0)
|
||||
parser.add_argument("--input-prompt-length", type=float, default=3.0)
|
||||
parser.add_argument("--input-prompt-prefix", action="store_true")
|
||||
parser.add_argument("--prefix-silence", type=float, default=0.0)
|
||||
parser.add_argument("--cfg-strength", type=float, default=0.0)
|
||||
|
||||
parser.add_argument("--top-p", type=float, default=1.0)
|
||||
parser.add_argument("--top-k", type=int, default=0)
|
||||
parser.add_argument("--top-no", type=float, default=0.0)
|
||||
parser.add_argument("--min-p", type=float, default=0.0)
|
||||
parser.add_argument("--repetition-penalty", type=float, default=1.125)
|
||||
parser.add_argument("--repetition-penalty", type=float, default=1.0)
|
||||
parser.add_argument("--repetition-penalty-decay", type=float, default=0.0)
|
||||
parser.add_argument("--length-penalty", type=float, default=0.0)
|
||||
parser.add_argument("--beam-width", type=int, default=0)
|
||||
|
||||
parser.add_argument("--mirostat-tau", type=float, default=0)
|
||||
parser.add_argument("--mirostat-eta", type=float, default=0)
|
||||
|
||||
|
||||
parser.add_argument("--dry-multiplier", type=float, default=0)
|
||||
parser.add_argument("--dry-base", type=float, default=1.75)
|
||||
parser.add_argument("--dry-allowed-length", type=int, default=2)
|
||||
|
||||
parser.add_argument("--entropix-sampling", action="store_true")
|
||||
|
||||
|
||||
parser.add_argument("--layer-skip", action="store_true")
|
||||
parser.add_argument("--layer-skip-exit-layer", type=int, default=None)
|
||||
parser.add_argument("--layer-skip-entropy-threshold", type=int, default=0.1)
|
||||
parser.add_argument("--layer-skip-varentropy-threshold", type=int, default=0.1)
|
||||
parser.add_argument("--refine-on-stop", action="store_true")
|
||||
|
||||
# experimental settings
|
||||
parser.add_argument("--load-from-artifact", type=Path, default=None)
|
||||
parser.add_argument("--denoise-start", type=float, default=0.0)
|
||||
|
||||
parser.add_argument("--seed", type=int, default=None)
|
||||
|
||||
|
@ -135,19 +152,19 @@ def main():
|
|||
comparison_kwargs["titles"] = ["LoRA", "No LoRA"]
|
||||
|
||||
comparison_kwargs["disabled"]["use_lora"] = True
|
||||
comparison_kwargs["disabled"]["ar_temp"] = 0.0
|
||||
comparison_kwargs["disabled"]["ar_temperature"] = 0.0
|
||||
comparison_kwargs["enabled"]["use_lora"] = False
|
||||
comparison_kwargs["enabled"]["ar_temp"] = 0.95
|
||||
comparison_kwargs["enabled"]["ar_temperature"] = 0.95
|
||||
elif args.comparison == "entropix-sampling":
|
||||
comparison_kwargs["suffix"] = "entropix_sampling"
|
||||
comparison_kwargs["titles"] = ["Without Entropix", "With Entropix"]
|
||||
|
||||
comparison_kwargs["disabled"]["entropix_sampling"] = False
|
||||
comparison_kwargs["disabled"]["ar_temp"] = args.ar_temp
|
||||
comparison_kwargs["disabled"]["ar_temperature"] = args.ar_temperature
|
||||
comparison_kwargs["disabled"]["top_k"] = args.top_k
|
||||
comparison_kwargs["disabled"]["top_p"] = args.top_p
|
||||
comparison_kwargs["enabled"]["entropix_sampling"] = True
|
||||
comparison_kwargs["enabled"]["ar_temp"] = 0.666
|
||||
comparison_kwargs["enabled"]["ar_temperature"] = 0.666
|
||||
comparison_kwargs["enabled"]["top_k"] = 27
|
||||
comparison_kwargs["enabled"]["top_p"] = 0.9
|
||||
elif args.comparison == "layerskip":
|
||||
|
@ -163,14 +180,14 @@ def main():
|
|||
comparison_kwargs["disabled"]["refine_on_stop"] = False
|
||||
comparison_kwargs["enabled"]["refine_on_stop"] = True
|
||||
elif args.comparison == "ar-temp":
|
||||
current_temp = args.ar_temp
|
||||
other_temp = 1.0
|
||||
current_temperature = args.ar_temperature
|
||||
other_temperature = 1.0
|
||||
|
||||
comparison_kwargs["suffix"] = "temperature"
|
||||
comparison_kwargs["titles"] = [f"Temp: {current_temp:.2f}", f"Temp: {other_temp:.2f}"]
|
||||
comparison_kwargs["titles"] = [f"Temp: {current_temperature:.2f}", f"Temp: {other_temperature:.2f}"]
|
||||
|
||||
comparison_kwargs["disabled"]["ar_temp"] = current_temp
|
||||
comparison_kwargs["enabled"]["ar_temp"] = other_temp
|
||||
comparison_kwargs["disabled"]["ar_temperature"] = current_temperature
|
||||
comparison_kwargs["enabled"]["ar_temperature"] = other_temperature
|
||||
elif args.comparison == "input-prompt-length":
|
||||
current_length = args.input_prompt_length
|
||||
other_length = 3.0
|
||||
|
@ -209,21 +226,34 @@ def main():
|
|||
# read html template
|
||||
html = open(args.demo_dir / "index.template.html", "r", encoding="utf-8").read()
|
||||
|
||||
# replace values in our template
|
||||
html = html.replace(r"${PREAMBLE}", args.preamble )
|
||||
html = html.replace(r"${SETTINGS}", str(dict(
|
||||
input_prompt_length=args.input_prompt_length,
|
||||
max_ar_steps=args.max_ar_steps, max_nar_levels=args.max_nar_levels,
|
||||
ar_temp=args.ar_temp, nar_temp=args.nar_temp,
|
||||
min_ar_temp=args.min_ar_temp, min_nar_temp=args.min_nar_temp,
|
||||
top_p=args.top_p, top_k=args.top_k, min_p=args.min_p,
|
||||
sampling_kwargs = dict(
|
||||
max_steps=args.max_steps,
|
||||
max_levels=args.max_levels,
|
||||
max_duration=args.max_duration,
|
||||
ar_temperature=args.ar_temperature, nar_temperature=args.nar_temperature,
|
||||
min_ar_temperature=args.min_ar_temperature, min_nar_temperature=args.min_nar_temperature,
|
||||
top_p=args.top_p, top_k=args.top_k, top_no=args.top_no,min_p=args.min_p,
|
||||
repetition_penalty=args.repetition_penalty, repetition_penalty_decay=args.repetition_penalty_decay,
|
||||
length_penalty=args.length_penalty,
|
||||
beam_width=args.beam_width,
|
||||
mirostat_tau=args.mirostat_tau, mirostat_eta=args.mirostat_eta,
|
||||
dry_multiplier=args.dry_multiplier, dry_base=args.dry_base, dry_allowed_length=args.dry_allowed_length,
|
||||
entropix_sampling=args.entropix_sampling,
|
||||
)) )
|
||||
layer_skip=args.layer_skip,
|
||||
layer_skip_exit_layer=args.layer_skip_exit_layer,
|
||||
layer_skip_entropy_threshold=args.layer_skip_entropy_threshold,
|
||||
layer_skip_varentropy_threshold=args.layer_skip_varentropy_threshold,
|
||||
refine_on_stop=args.refine_on_stop,
|
||||
denoise_start=args.denoise_start,
|
||||
input_prompt_length=args.input_prompt_length,
|
||||
input_prompt_prefix=args.input_prompt_prefix,
|
||||
prefix_silence=args.prefix_silence,
|
||||
cfg_strength=args.cfg_strength,
|
||||
)
|
||||
|
||||
# replace values in our template
|
||||
html = html.replace(r"${PREAMBLE}", args.preamble )
|
||||
html = html.replace(r"${SETTINGS}", str(sampling_kwargs))
|
||||
|
||||
# pull from provided samples
|
||||
samples_dirs = {
|
||||
|
@ -324,18 +354,9 @@ def main():
|
|||
references=[prompt],
|
||||
language=language,
|
||||
input_prompt_length=args.input_prompt_length,
|
||||
max_ar_steps=args.max_ar_steps, max_nar_levels=args.max_nar_levels,
|
||||
ar_temp=args.ar_temp, nar_temp=args.nar_temp,
|
||||
min_ar_temp=args.min_ar_temp, min_nar_temp=args.min_nar_temp,
|
||||
top_p=args.top_p, top_k=args.top_k,
|
||||
repetition_penalty=args.repetition_penalty, repetition_penalty_decay=args.repetition_penalty_decay,
|
||||
length_penalty=args.length_penalty,
|
||||
beam_width=args.beam_width,
|
||||
mirostat_tau=args.mirostat_tau, mirostat_eta=args.mirostat_eta,
|
||||
dry_multiplier=args.dry_multiplier, dry_base=args.dry_base, dry_allowed_length=args.dry_allowed_length,
|
||||
entropix_sampling=args.entropix_sampling,
|
||||
seed=seed,
|
||||
tqdm=False,
|
||||
**sampling_kwargs,
|
||||
)
|
||||
|
||||
def safe_inference( out_path=out_path ):
|
||||
|
|
|
@ -31,6 +31,8 @@ from .lora import enable_lora
|
|||
text_task = [ "stt" ]
|
||||
|
||||
class AR_NAR(Base):
|
||||
# parse inputs for training
|
||||
# a lot of this could be delegated back to the dataloader, but it's just easier to keep the task of the dataloader to provide sufficient data, and the model to process the data for training
|
||||
def forward_train(
|
||||
self,
|
||||
text_list: list[Tensor],
|
||||
|
@ -62,19 +64,17 @@ class AR_NAR(Base):
|
|||
token_dropout_rvq_levels = self.config.experimental.token_dropout_rvq_levels
|
||||
# RVQ levels to apply masking training on
|
||||
masking_train_rvq_levels = self.config.experimental.masking_train_rvq_levels
|
||||
|
||||
# force set mask training
|
||||
if "len" not in self.capabilities:
|
||||
masking_train_rvq_levels = 0.0
|
||||
elif "ar" not in self.capabilities:
|
||||
masking_train_rvq_levels = 1.0
|
||||
|
||||
# CFG
|
||||
cfg_text_dropout_p = self.config.experimental.cfg_text_dropout_p if self.config is not None else 0.0
|
||||
cfg_cond_dropout_p = self.config.experimental.cfg_cond_dropout_p if self.config is not None else 0.0
|
||||
cfg_prom_dropout_p = self.config.experimental.cfg_prom_dropout_p if self.config is not None else 0.0
|
||||
# rate to train RVQ level AR-ly or NAR-ly
|
||||
masking_train_p = self.config.experimental.masking_train_p if self.config is not None else 0.5
|
||||
# force set mask training
|
||||
if "len" not in self.capabilities:
|
||||
masking_train_p = 0.0
|
||||
elif "ar" not in self.capabilities:
|
||||
masking_train_p = 1.0
|
||||
# implicitly set it to all levels
|
||||
if not token_dropout_rvq_levels:
|
||||
token_dropout_rvq_levels = [0, self.resp_levels - 1]
|
||||
|
@ -116,9 +116,11 @@ class AR_NAR(Base):
|
|||
text_stop_sequence = torch.tensor([2], device=device, dtype=torch.int16)
|
||||
text_start_stop_sequence = torch.tensor([1, 2], device=device, dtype=torch.int16)
|
||||
audio_stop_sequence = torch.tensor([[self.stop_token]], device=device, dtype=torch.int16)
|
||||
# I hate python's value/reference semantics so much
|
||||
|
||||
# final validations and stuff
|
||||
for i, quant_level, resps, proms, task in zip(range(batch_size), quant_levels, resps_list, proms_list, task_list):
|
||||
# cap quant_level if it exceeds its corresponding resp/prom
|
||||
# this was needed for when my DAC-encoded audio was erroneously trimmed to 8 RVQ levels instead of 9
|
||||
if quant_level >= resps.shape[-1]:
|
||||
quant_levels[i] = resps.shape[-1] - 1
|
||||
|
||||
|
|
|
@ -1706,7 +1706,9 @@ class Base(nn.Module):
|
|||
dry_multiplier = sampling_kwargs.get("dry_multiplier", 0.0)
|
||||
dry_base = sampling_kwargs.get("dry_base", 1.75)
|
||||
dry_allowed_length = sampling_kwargs.get("dry_allowed_length", 2)
|
||||
|
||||
#
|
||||
top_no = sampling_kwargs.get("top_no", 1.0)
|
||||
#
|
||||
attentions = sampling_kwargs.get("attentions", None)
|
||||
|
||||
batch_size = len( logits )
|
||||
|
@ -1792,6 +1794,10 @@ class Base(nn.Module):
|
|||
elif temperature > 0.0:
|
||||
logits = [ logit / temperature for logit in logits ]
|
||||
|
||||
# do top-no logit processing
|
||||
if top_no > 0.0:
|
||||
logits = [ top_no_logits_processing(logit) for logit in logits ]
|
||||
|
||||
# do DRY sampling
|
||||
if dry_multiplier > 0.0 and prev_list is not None:
|
||||
logits = [ dry_sampling(logit, previous=prevs, factor=dry_multiplier, base=dry_base, allowed_length=dry_allowed_length) for logit, prevs in zip( logits, prev_list ) ]
|
||||
|
|
|
@ -159,6 +159,19 @@ def top_k_logits_list( logits_list, k ):
|
|||
candidates[i] = tuple(t)
|
||||
return candidates
|
||||
|
||||
# top-nσ logit processing
|
||||
# from https://arxiv.org/abs/2411.07641
|
||||
def top_no_logits_processing( logits, n = 1.0 ):
|
||||
M = torch.max(logits, dim=-1, keepdim=True).values
|
||||
σ = torch.std(logits, dim=-1, keepdim=True)
|
||||
|
||||
mask = logits >= M - n * σ
|
||||
n_inf = torch.full_like( logits, -float("inf") )
|
||||
logits = torch.where( mask, logits, n_inf )
|
||||
|
||||
return logits
|
||||
|
||||
|
||||
|
||||
# Credit to: https://github.com/basusourya/mirostat/
|
||||
# performs mirostat-based sampling
|
||||
|
|
|
@ -44,6 +44,9 @@ class PoolSampler():
|
|||
def __call__(self, *args, **kwargs):
|
||||
return self.sample(*args, **kwargs)
|
||||
|
||||
def index(self):
|
||||
return len(self.global_indices) - len(self.current_pool)
|
||||
|
||||
def get_state(self):
|
||||
return { "length": self.length, "global_pool": self.global_pool, "global_indices": self.global_indices, "current_pool": self.current_pool }
|
||||
|
||||
|
@ -72,6 +75,9 @@ class OrderedSampler(Sampler):
|
|||
yield self.position
|
||||
self.position += 1
|
||||
|
||||
def index(self):
|
||||
return self.position
|
||||
|
||||
def get_state(self):
|
||||
return { "position": self.position, "length": self.length }
|
||||
|
||||
|
@ -125,6 +131,9 @@ class BatchedOrderedSampler(Sampler):
|
|||
yield self.batches[self.position]
|
||||
self.position += 1
|
||||
|
||||
def index(self):
|
||||
return self.position
|
||||
|
||||
def get_state(self):
|
||||
return { "position": self.position, "batches": self.batches }
|
||||
|
||||
|
@ -154,6 +163,9 @@ class RandomSampler(Sampler):
|
|||
yield self.perm[self.position]
|
||||
self.position += 1
|
||||
|
||||
def index(self):
|
||||
return self.position
|
||||
|
||||
def get_state(self):
|
||||
return { "position": self.position, "length": self.length, "perm": self.perm, "generator": self.generator.get_state() }
|
||||
|
||||
|
|
|
@ -104,7 +104,7 @@ def _non_blocking_input():
|
|||
def _make_infinite_epochs(dl):
|
||||
while True:
|
||||
#_logger.info("New epoch starts.")
|
||||
yield from tqdm(dl, "Epoch progress", dynamic_ncols=True, disable=not is_global_leader())
|
||||
yield from tqdm(dl, "Epoch progress", dynamic_ncols=True, disable=not is_global_leader()) # , initial=dl.dataset.index(), total=len(dl.dataset)) # to-do: figure out why this number jumps
|
||||
|
||||
|
||||
@local_leader_only(default=None)
|
||||
|
|
|
@ -216,6 +216,7 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
|
|||
parser.add_argument("--prefix-silence", type=float, default=kwargs["prefix-silence"])
|
||||
parser.add_argument("--top-p", type=float, default=kwargs["top-p"])
|
||||
parser.add_argument("--top-k", type=int, default=kwargs["top-k"])
|
||||
parser.add_argument("--top-no", type=float, default=kwargs["top-no"])
|
||||
parser.add_argument("--min-p", type=float, default=kwargs["min-p"])
|
||||
parser.add_argument("--repetition-penalty", type=float, default=kwargs["repetition-penalty"])
|
||||
parser.add_argument("--repetition-penalty-decay", type=float, default=kwargs["repetition-penalty-decay"])
|
||||
|
@ -265,7 +266,7 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
|
|||
max_duration=args.max_duration,
|
||||
ar_temperature=args.ar_temperature, nar_temperature=args.nar_temperature,
|
||||
min_ar_temperature=args.min_ar_temperature, min_nar_temperature=args.min_nar_temperature,
|
||||
top_p=args.top_p, top_k=args.top_k, min_p=args.min_p,
|
||||
top_p=args.top_p, top_k=args.top_k, min_p=args.min_p, top_no=args.top_no,
|
||||
repetition_penalty=args.repetition_penalty, repetition_penalty_decay=args.repetition_penalty_decay,
|
||||
length_penalty=args.length_penalty,
|
||||
beam_width=args.beam_width,
|
||||
|
@ -434,8 +435,8 @@ with ui:
|
|||
with gr.Row():
|
||||
layout["inference_tts"]["inputs"]["top-p"] = gr.Slider(value=1.0, minimum=0.0, maximum=1.0, step=0.05, label="Top P", info=r"Limits the samples that are outside the top P% of probabilities.")
|
||||
layout["inference_tts"]["inputs"]["top-k"] = gr.Slider(value=0, minimum=0, maximum=1024, step=1, label="Top K", info="Limits the samples to the top K of probabilities.")
|
||||
layout["inference_tts"]["inputs"]["top-no"] = gr.Slider(value=1, minimum=0, maximum=2, step=0.05, label="Top-nσ", info="Performs top-nσ logits processing.")
|
||||
layout["inference_tts"]["inputs"]["min-p"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.0, step=0.05, label="Min P")
|
||||
layout["inference_tts"]["inputs"]["beam-width"] = gr.Slider(value=0, minimum=0, maximum=32, step=1, label="Beam Width", info="Number of branches to search through for beam search sampling.")
|
||||
with gr.Row():
|
||||
layout["inference_tts"]["inputs"]["repetition-penalty"] = gr.Slider(value=1.0, minimum=0.0, maximum=5.0, step=0.05, label="Repetition Penalty", info="Incurs a penalty to tokens based on how often they appear in a sequence.")
|
||||
layout["inference_tts"]["inputs"]["repetition-penalty-decay"] = gr.Slider(value=0.0, minimum=-2.0, maximum=2.0, step=0.05, label="Repetition Penalty Length Decay", info="Modifies the reptition penalty based on how far back in time the token appeared in the sequence.")
|
||||
|
@ -451,10 +452,11 @@ with ui:
|
|||
with gr.Row():
|
||||
layout["inference_tts"]["inputs"]["max-steps"] = gr.Slider(value=25, minimum=1, maximum=500, step=1, label="Max NAR Steps", info="Limits how many steps to perform in the NAR (demask) pass.")
|
||||
layout["inference_tts"]["inputs"]["max-levels"] = gr.Slider(value=7, minimum=0, maximum=7, step=1, label="Max NAR Levels", info="Limits how many steps to perform in the NAR pass.")
|
||||
layout["inference_tts"]["inputs"]["input-prompt-prefix"] = gr.Checkbox(label="Input Prompt as Prefix", info="Treats the input prompt clip as the prefix of the generated sequence.")
|
||||
with gr.Row():
|
||||
layout["inference_tts"]["inputs"]["input-prompt-prefix"] = gr.Checkbox(label="Input Prompt as Prefix", info="Treats the input prompt clip as the prefix of the generated sequence.")
|
||||
layout["inference_tts"]["inputs"]["prefix-silence"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.0, step=0.05, label="Silence Prefix Duration", info="Amount of silence to prefix to the output response before beginning inference.")
|
||||
with gr.Row():
|
||||
layout["inference_tts"]["inputs"]["beam-width"] = gr.Slider(value=0, minimum=0, maximum=32, step=1, label="Beam Width", info="Number of branches to search through for beam search sampling.")
|
||||
layout["inference_tts"]["inputs"]["dynamic-sampling"] = gr.Checkbox(label="Dynamic Temperature", info="Dynamically adjusts the temperature based on the highest confident predicted token per sampling step.")
|
||||
layout["inference_tts"]["inputs"]["entropix-sampling"] = gr.Checkbox(label="Entropix Sampling", info="Dynamically samples based on entropy/varentropy values from the logits / attention scores.")
|
||||
with gr.Row():
|
||||
|
|
Loading…
Reference in New Issue
Block a user