logic fixes, I feel like the output is better? (also, the NAR can have a temperature; I imagine it couldn't before because it was having a causal mask passed to it for the longest time, before I caught it a month ago)

mrq 2024-12-08 14:52:47 -06:00
parent 0c5a458b00
commit 1d460b9fe3
5 changed files with 70 additions and 110 deletions
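
For context on the parenthetical above: a NAR (non-autoregressive) pass should attend over the full sequence, so an accidentally-applied causal mask truncates each position's context and leaves temperature acting on malformed distributions. A minimal sketch of the distinction, not this repository's code:

import torch

T = 6
causal_mask = torch.tril(torch.ones(T, T, dtype=torch.bool))  # AR: position t sees only <= t
full_mask   = torch.ones(T, T, dtype=torch.bool)              # NAR: position t sees everything

def sample(logits: torch.Tensor, temperature: float) -> torch.Tensor:
    # with the full mask in place, temperature behaves as expected:
    # 0 degenerates to greedy argmax, higher values flatten the distribution
    if temperature <= 0:
        return logits.argmax(dim=-1)
    probs = torch.softmax(logits / temperature, dim=-1)
    return torch.multinomial(probs, 1).squeeze(-1)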


@@ -61,7 +61,7 @@ def sentence_split( s, split_by="sentences", quote_placeholder="<QUOTE>" ):
# nltk does not split quotations all that nicely, so we coerce them into placeholders, then replace afterwards
s = s.replace('"', quote_placeholder)
sentences = nltk.sent_tokenize(s)
return [ sentence.replace(quote_placeholder, '"') for sentence in sentences ]
return [ sentence.replace(quote_placeholder, '"') for sentence in sentences if sentence ]
@cache
def get_random_prompts( validation=False, min_length=0, tokenized=False ):
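
A quick illustration of the placeholder round-trip above, and of the empty-string guard this hunk adds (hypothetical input; assumes NLTK's punkt data is installed):

import nltk  # assumes the punkt tokenizer data has been downloaded

s = 'He said "Stop. Wait." and left.'
placeholder = "<QUOTE>"
sentences = nltk.sent_tokenize(s.replace('"', placeholder))
# the new `if sentence` filter drops any empty strings a degenerate split
# might produce, so no zero-length utterance reaches the TTS pipeline
restored = [ p.replace(placeholder, '"') for p in sentences if p ]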


@@ -123,6 +123,7 @@ def main():
parser.add_argument("--device", type=str, default=None)
parser.add_argument("--amp", action="store_true")
parser.add_argument("--dtype", type=str, default=None)
parser.add_argument("--attention", type=str, default="auto")
parser.add_argument("--random-prompts", action="store_true")
parser.add_argument("--lora", action="store_true")
@@ -136,7 +137,7 @@
elif args.model:
config = args.model
tts = TTS( config=config, lora=args.lora, device=args.device, dtype=args.dtype, amp=args.amp )
tts = TTS( config=config, lora=args.lora, device=args.device, dtype=args.dtype, amp=args.amp, attention=args.attention )
if not args.demo_dir:
args.demo_dir = Path("./data/demo/")


@@ -300,8 +300,6 @@ class TTS():
if model_len is not None:
# extra kwargs
duration_padding = sampling_kwargs.pop("duration_padding", 1.05)
nar_len_prefix_length = sampling_kwargs.pop("nar_len_prefix_length", 0)
len_list = model_len( **input_kwargs, task_list=["len"]*batch_size, **{"max_duration": 5} ) # "max_duration" is max tokens
# add an additional X seconds
@@ -443,8 +441,6 @@ class TTS():
if model_len is not None:
# extra kwargs
duration_padding = sampling_kwargs.pop("duration_padding", 1.05)
nar_len_prefix_length = sampling_kwargs.pop("nar_len_prefix_length", 0)
len_list = model_len( **input_kwargs, task_list=["len"], **{"max_duration": 5} ) # "max_duration" is max tokens
# add an additional X seconds
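
The scaling that `duration_padding` feeds isn't visible in these hunks; presumably the predicted token length from the `len` task is padded by ~5% for headroom, along the lines of:

# assumed usage of duration_padding (the actual line sits outside this hunk)
duration_padding = 1.05
len_list = [100, 87]                                        # hypothetical predicted token counts
len_list = [ int(l * duration_padding) for l in len_list ]  # -> [105, 91]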


@@ -223,31 +223,19 @@ class AR_NAR(Base):
device = text_list[0].device
batch_size = len(text_list)
# special "scheduling" to inference RVQ-level 0
level = 0
if cfg.lora is not None:
enable_lora( self, cfg.lora.active_level( level ) if use_lora is None else use_lora )
# to-do: check if gumbel sampling works / helps
"""
def log(x, eps = 1e-20):
return torch.log(x.clamp(min = eps))
def gumbel_sample(x, temperature = 1., dim = -1):
return ((x / max(temperature, 1e-10)) + -log(-log(torch.zeros_like(x).uniform_(0, 1)))).argmax(dim = dim)
"""
def log(t, eps=1e-10):
return torch.log(t + eps)
def gumbel_noise(t):
noise = torch.zeros_like(t).uniform_(0, 1)
return -log(-log(noise))
def gumbel_sample(t, temperature=1.0, dim=-1):
return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim=dim)
"""
# convert (N)AR specific args
sampling_kwargs = convert_kwargs( sampling_kwargs, "ar_" )
@@ -261,18 +249,18 @@ class AR_NAR(Base):
# greedy sampling is very, very much preferred, but using greedy logit scores later helps enough
temperature = sampling_kwargs.pop("temperature", 0.0)
minimum_cfg_strength = sampling_kwargs.get("minimum_cfg_strength", 2.5)
# this really helps keep audio coherent so far
cfg_strength = sampling_kwargs.get("cfg_strength", 2.0)
cfg_strength = sampling_kwargs.get("cfg_strength", minimum_cfg_strength)
cfg_rescale = sampling_kwargs.pop("cfg_rescale", 0.75)
start_noise = sampling_kwargs.get("denoise_start", 0.0)
end_noise = sampling_kwargs.get("denoise_end", 1.0)
remasking = sampling_kwargs.get("remasking", True)
remasking = sampling_kwargs.get("remasking", False)
max_steps = math.floor(max_steps * (end_noise - start_noise))
len_list = [ clamp(l, min_length, max_length) for l in len_list ]
# force set CFG because too low / no CFG causes issues
minimum_cfg_strength = sampling_kwargs.get("minimum_cfg_strength", 3.0)
original_cfg_strength = cfg_strength
cfg_strength = max( cfg_strength, minimum_cfg_strength )
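
`cfg_logits` itself is defined outside this diff; a plausible sketch of classifier-free guidance with rescale, following the common formulation (the repository's exact math may differ):

import torch

def cfg_logits_sketch(cond: torch.Tensor, null: torch.Tensor, strength: float, rescale: float) -> torch.Tensor:
    # classifier-free guidance: extrapolate the conditional logits away from
    # the unconditional (null) logits by `strength`
    guided = null + (cond - null) * strength
    if rescale > 0:
        # CFG rescale: re-match the guided std to the conditional std, then
        # blend by phi to curb over-saturation at high guidance strengths
        rescaled = guided * (cond.std() / guided.std())
        guided = rescale * rescaled + (1.0 - rescale) * guided
    return guided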
@@ -320,17 +308,19 @@ class AR_NAR(Base):
time_list = [ timestep for _ in range(batch_size) ]
sampling_temperature = temperature * annealing if annealed_sampling else temperature
sampling_cfg = cfg_strength * timestep if annealed_sampling else temperature
sampling_cfg = cfg_strength * timestep if annealed_sampling else cfg_strength
# avoid useless CFG sampling
"""
if sampling_cfg < minimum_cfg_strength * 0.5:
sampling_cfg = 0
"""
if prefix_context is not None:
input_resps_list = [ torch.concat( [ prefix, resps ] ) for prefix, resps in zip( prefix_resps_list, resps_list ) ]
# originally requested no CFG, safe to ignore if we have a prefix
if original_cfg_strength == 0:
sampling_cfg = 0
if original_cfg_strength < minimum_cfg_strength:
sampling_cfg = original_cfg_strength * timestep if annealed_sampling else original_cfg_strength
else:
input_resps_list = resps_list
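
A worked example of the branch fixed above: before this commit the non-annealed path fell back to `temperature`, which silently zeroed CFG whenever annealing was off and sampling was greedy.

# walking the fixed line (values hypothetical)
cfg_strength, temperature, annealed_sampling, timestep = 3.0, 0.0, False, 0.5
# fixed branch: non-annealed sampling keeps the full CFG strength
sampling_cfg = cfg_strength * timestep if annealed_sampling else cfg_strength  # -> 3.0
# before the fix the else-branch read `else temperature`, yielding 0.0 and
# disabling classifier-free guidance for temperature-0 runs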
@@ -347,7 +337,6 @@ class AR_NAR(Base):
output = super().forward(
inputs=inputs,
quant_levels=quant_levels,
#layer_skip_variables=sampling_layer_skip_variables,
)
logits = output.logits
@@ -365,7 +354,6 @@ class AR_NAR(Base):
null_output = super().forward(
inputs=null_inputs,
quant_levels=quant_levels,
#layer_skip_variables=sampling_layer_skip_variables,
)
logits = cfg_logits( logits=output.logits, null=null_output.logits, strength=cfg_strength, rescale=cfg_rescale, lens=[ l for l in len_list ] )
@@ -420,49 +408,6 @@ class AR_NAR(Base):
use_lora=None,
**sampling_kwargs,
):
# deduce batch_size
if text_list is not None:
default_task = "tts"
device = text_list[0].device
batch_size = len(text_list)
else:
default_task = "stt"
device = resps_list[0].device
batch_size = len(resps_list)
# convert NAR specific args
sampling_kwargs = convert_kwargs( sampling_kwargs, "nar_" )
max_levels = sampling_kwargs.get("max_levels", 0)
cfg_strength = sampling_kwargs.get("cfg_strength", 0.0)
cfg_rescale = sampling_kwargs.pop("cfg_rescale", 0.7)
if max_levels == 0:
max_levels = self.n_max_levels - 1
# prefixed context provided
"""
prefix_context = sampling_kwargs.get("prefix_context", None)
if prefix_context is not None:
prefix_text, prefix_resps, _ = prefix_context
# to-do: check if we actually need to drop the middle "<eos><bos>"
text_list = [ torch.concat([prefix[:-1], text[1:]]) for prefix, text in zip( prefix_text, text_list ) ]
# feeding this into the NAR-len should automatically handle things
resps_list = [ resps for resps in prefix_resps ]
"""
"""
sampling_layer_skip_variables = {} if sampling_layer_skip else None
if sampling_layer_skip:
if sampling_layer_skip_entropy_threshold >= 0:
sampling_layer_skip_variables["entropy_threshold"] = sampling_layer_skip_entropy_threshold
if sampling_layer_skip_varentropy_threshold >= 0:
sampling_layer_skip_variables["varentropy_threshold"] = sampling_layer_skip_varentropy_threshold
if sampling_layer_skip_exit_layer >= 0:
sampling_layer_skip_variables["max_layer"] = sampling_layer_skip_exit_layer
"""
# inference NAR level 0
if len_list is not None:
resps_list = self.forward_nar_masked(
@@ -475,6 +420,26 @@ class AR_NAR(Base):
len_list=len_list,
**sampling_kwargs,
)
# deduce batch_size
if text_list is not None:
default_task = "tts"
device = text_list[0].device
batch_size = len(text_list)
else:
default_task = "stt"
device = resps_list[0].device
batch_size = len(resps_list)
# convert NAR specific args
sampling_kwargs = convert_kwargs( sampling_kwargs, "nar_" )
max_levels = sampling_kwargs.get("max_levels", 0)
cfg_strength = sampling_kwargs.get("cfg_strength", 0.0)
cfg_rescale = sampling_kwargs.pop("cfg_rescale", 0.7)
if max_levels == 0:
max_levels = self.n_max_levels - 1
# expand if given a raw 1D tensor
for i, resp in enumerate(resps_list):
@@ -489,14 +454,14 @@ class AR_NAR(Base):
iterator = trange( max_levels, desc="NAR", disable=disable_tqdm )
for n in iterator:
level = prev_list[0].shape[-1]
if level >= max_levels + 1: # min(max_levels + 1, self.n_resp_levels): # commented out to experiment with exceeding trained levels
if level >= max_levels + 1:
iterator.close()
break
if cfg.lora is not None:
enable_lora( self, cfg.lora.active_level( level ) if use_lora is None else use_lora )
quant_levels = [ level for _ in range(batch_size) ] # torch.full((len(text_list),), level)
quant_levels = [ level for _ in range(batch_size) ]
inputs = self.inputs(
text_list=text_list,
@@ -510,7 +475,6 @@ class AR_NAR(Base):
output = super().forward(
inputs=inputs,
quant_levels=quant_levels,
#layer_skip_variables=sampling_layer_skip_variables,
)
logits, state = output.logits, output.state
@@ -526,7 +490,6 @@ class AR_NAR(Base):
null_output = super().forward(
inputs=null_inputs,
quant_levels=quant_levels,
#layer_skip_variables=sampling_layer_skip_variables,
)
logits = cfg_logits( logits=output.logits, null=null_output.logits, strength=cfg_strength, rescale=cfg_rescale, lens=[ resp.shape[0] for resp in resps_list ] )
@@ -535,8 +498,7 @@ class AR_NAR(Base):
logits=logits,
prev_list=prev_list,
quant_levels=quant_levels,
#temperature=0.0,
**(sampling_kwargs | {"temperature": 0.0}),
**(sampling_kwargs),
)
resps_list = sampled.ids
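
Schematically, the loop above fills one residual codebook level per pass, reading the next level index off `prev_list[0].shape[-1]` (shapes assumed here):

import torch

# assumed shapes: each response is a [T, n_levels] tensor of codebook ids
prev = torch.randint(0, 1024, (75, 1))   # level 0, from the AR or NAR-len pass
max_levels = 7
while prev.shape[-1] < max_levels + 1:
    level = prev.shape[-1]               # index of the level about to be predicted
    sampled_ids = torch.randint(0, 1024, (prev.shape[0], 1))  # stand-in for the model's sample
    prev = torch.cat([prev, sampled_ids], dim=-1)
assert prev.shape == (75, 8)             # all eight RVQ levels populated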


@@ -440,31 +440,43 @@ with ui:
with gr.Column(scale=7):
with gr.Tab("Basic Settings"):
with gr.Row():
layout["inference_tts"]["inputs"]["max-duration"] = gr.Slider(value=12, minimum=1, maximum=32, step=0.1, label="Maximum Duration", info="Limits how many steps to perform in the AR pass.")
layout["inference_tts"]["inputs"]["max-steps"] = gr.Slider(value=50, minimum=1, maximum=200, step=1, label="Max Steps (NAR-len)", info="Limits how many steps to perform in the NAR-len (demask) pass.")
layout["inference_tts"]["inputs"]["input-prompt-length"] = gr.Slider(value=5.0, minimum=0.0, maximum=12.0, step=0.05, label="Input Prompt Repeat/Trim Length", info="Repeats and trims the input prompt down to X seconds. Set 0 to disable.")
layout["inference_tts"]["inputs"]["max-steps"] = gr.Slider(value=50, minimum=1, maximum=200, step=1, label="Max Steps", info="Limits how many steps to perform in the NAR-len (demask) pass.")
layout["inference_tts"]["inputs"]["max-duration"] = gr.Slider(value=12, minimum=1, maximum=32, step=0.1, label="Maximum Duration", info="Limits how long an utterance can be.")
layout["inference_tts"]["inputs"]["input-prompt-length"] = gr.Slider(value=0.0, minimum=0.0, maximum=12.0, step=0.5, label="Input Prompt Repeat/Trim Length", info="Repeats/trims the input prompt down to X seconds (0 to disable).")
with gr.Row():
layout["inference_tts"]["inputs"]["ar-temperature"] = gr.Slider(value=1.0, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (AR/NAR-len)", info="Modifies the randomness from the samples in the AR/NAR-len. (0 to greedy* sample)")
layout["inference_tts"]["inputs"]["nar-temperature"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (NAR)", info="Modifies the randomness from the samples in the NAR. (0 to greedy sample)")
layout["inference_tts"]["inputs"]["modality"] = gr.Dropdown(value="Auto", choices=["Auto", "AR+NAR", "NAR-len"], label="Modality", info="Whether to inference with the AR+NAR or through the NAR-len.")
with gr.Row():
layout["inference_tts"]["inputs"]["cfg-strength"] = gr.Slider(value=1.0, minimum=0.0, maximum=14.0, step=0.05, label="CFG Strength", info="Classifier Free Guidance scale (AR needs 1, NAR-len needs 3).")
layout["inference_tts"]["inputs"]["cfg-rescale"] = gr.Slider(value=0.75, minimum=0.0, maximum=1.0, step=0.05, label="CFG Rescale (Phi)", info="Factor when rescaling for Classifier Free Guidance (0 to disable).")
layout["inference_tts"]["inputs"]["language"] = gr.Dropdown(choices=get_languages(), label="Language (Output)", value="auto", info="Target language/accent to output.")
layout["inference_tts"]["inputs"]["text-language"] = gr.Dropdown(choices=get_languages(), label="Language (Text)", value="auto", info="Language the input text is in.")
layout["inference_tts"]["inputs"]["language"] = gr.Dropdown(choices=get_languages(), label="Language (Output)", value="auto", info="Target language/accent to output.")
with gr.Row():
layout["inference_tts"]["inputs"]["split-text-by"] = gr.Dropdown(choices=["sentences", "lines"], label="Text Delimiter", info="Splits the text into pieces.", value="sentences")
layout["inference_tts"]["inputs"]["split-text-by"] = gr.Dropdown(choices=["sentences", "lines"], label="Text Delimiter", info="How to split the text into utterances.", value="sentences")
layout["inference_tts"]["inputs"]["context-history"] = gr.Slider(value=0, minimum=0, maximum=4, step=1, label="(Rolling) Context History", info="How many prior lines to serve as the context/prefix (0 to disable).")
with gr.Tab("Sampler Settings"):
with gr.Row():
layout["inference_tts"]["inputs"]["ar-temperature"] = gr.Slider(value=1.0, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (AR/NAR-len)", info="Adjusts the probabilities in the AR/NAR-len. (0 to greedy* sample)")
layout["inference_tts"]["inputs"]["nar-temperature"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (NAR)", info="Adjusts the probabilities in the NAR. (0 to greedy sample)")
layout["inference_tts"]["inputs"]["modality"] = gr.Dropdown(value="Auto", choices=["Auto", "AR+NAR", "NAR-len"], label="Modality", info="Whether to inference with the AR+NAR or through the NAR-len.")
with gr.Row():
layout["inference_tts"]["inputs"]["cfg-strength"] = gr.Slider(value=1.0, minimum=0.0, maximum=14.0, step=0.5, label="CFG Strength", info="Classifier Free Guidance scale (AR needs 1, NAR-len needs 3).")
layout["inference_tts"]["inputs"]["cfg-rescale"] = gr.Slider(value=0.75, minimum=0.0, maximum=1.0, step=0.05, label="CFG Rescale (Phi)", info="Factor when rescaling for Classifier Free Guidance (0 to disable).")
with gr.Row():
layout["inference_tts"]["inputs"]["min-p"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.0, step=0.05, label="Min P", info="Filter out logits lower than this value.")
layout["inference_tts"]["inputs"]["top-p"] = gr.Slider(value=1.0, minimum=0.0, maximum=1.0, step=0.05, label="Top P", info=r"Limits the samples that are outside the top P% of probabilities.")
layout["inference_tts"]["inputs"]["top-k"] = gr.Slider(value=0, minimum=0, maximum=1024, step=1, label="Top K", info="Limits the samples to the top K of probabilities.")
layout["inference_tts"]["inputs"]["top-no"] = gr.Slider(value=0, minimum=0, maximum=2, step=0.05, label="Top-nσ", info="Performs top-nσ logits processing.")
layout["inference_tts"]["inputs"]["min-p"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.0, step=0.05, label="Min P", info="Filter out logits lower than this value.")
layout["inference_tts"]["inputs"]["top-no"] = gr.Slider(value=0, minimum=0, maximum=2, step=0.5, label="Top-nσ", info="Performs top-nσ logits processing.")
with gr.Row():
layout["inference_tts"]["inputs"]["repetition-penalty"] = gr.Slider(value=1.0, minimum=0.0, maximum=5.0, step=0.05, label="Repetition Penalty", info="Incurs a penalty to tokens based on how often they appear in a sequence.")
layout["inference_tts"]["inputs"]["repetition-penalty-decay"] = gr.Slider(value=0.0, minimum=-2.0, maximum=2.0, step=0.05, label="Repetition Penalty Length Decay", info="Modifies the reptition penalty based on how far back in time the token appeared in the sequence.")
layout["inference_tts"]["inputs"]["length-penalty"] = gr.Slider(value=0.0, minimum=-2.0, maximum=2.0, step=0.05, label="Length Penalty", info="(AR only) Modifies the probability of a stop token based on the current length of the sequence.")
# These settings are pretty much not supported anyways
with gr.Tab("Experimental Settings", visible=cfg.experimental):
with gr.Row():
layout["inference_tts"]["inputs"]["max-levels"] = gr.Slider(value=7, minimum=0, maximum=7, step=1, label="Max NAR Levels", info="Limits how many steps to perform in the NAR pass.")
layout["inference_tts"]["inputs"]["beam-width"] = gr.Slider(value=0, minimum=0, maximum=32, step=1, label="Beam Width", info="Number of branches to search through for beam search sampling.")
layout["inference_tts"]["inputs"]["prefix-silence"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.0, step=0.5, label="Silence Prefix Duration", info="Amount of silence to prefix to the output response before beginning inference.")
with gr.Row():
layout["inference_tts"]["inputs"]["input-prompt-prefix"] = gr.Checkbox(label="Input Prompt as Prefix", info="Treats the input prompt clip as the prefix of the generated sequence.")
layout["inference_tts"]["inputs"]["dynamic-sampling"] = gr.Checkbox(label="Dynamic Temperature", info="Dynamically adjusts the temperature based on the highest confident predicted token per sampling step.")
layout["inference_tts"]["inputs"]["entropix-sampling"] = gr.Checkbox(label="Entropix Sampling", info="Dynamically samples based on entropy/varentropy values from the logits / attention scores.")
layout["inference_tts"]["inputs"]["refine-on-stop"] = gr.Checkbox(label="Refine on <stop>", info="Uses the last step's logits for the AR sequence instead.")
with gr.Row():
layout["inference_tts"]["inputs"]["mirostat-tau"] = gr.Slider(value=0.0, minimum=0.0, maximum=8.0, step=0.05, label="Mirostat τ (Tau)", info="The \"surprise\" value when performing mirostat sampling. 0 to disable.")
layout["inference_tts"]["inputs"]["mirostat-eta"] = gr.Slider(value=0.0, minimum=0.0, maximum=2.0, step=0.05, label="Mirostat η (Eta)", info="The \"learning rate\" during mirostat sampling applied to the maximum surprise.")
@@ -472,22 +484,11 @@ with ui:
layout["inference_tts"]["inputs"]["dry-multiplier"] = gr.Slider(value=0.0, minimum=0.0, maximum=8.0, step=0.05, label="DRY Multiplier", info="The multiplying factor for the DRY score penalty (0 to disable DRY sampling).")
layout["inference_tts"]["inputs"]["dry-base"] = gr.Slider(value=1.75, minimum=0.0, maximum=8.0, step=0.05, label="DRY Base", info="The base of the exponent in the DRY score penalty")
layout["inference_tts"]["inputs"]["dry-allowed-length"] = gr.Slider(value=2, minimum=0, maximum=75, step=1, label="Allowed Length", info="The maximimum length a token can be to perform DRY penalty with.")
with gr.Tab("Experimental Settings", visible=cfg.experimental):
with gr.Row():
layout["inference_tts"]["inputs"]["max-levels"] = gr.Slider(value=7, minimum=0, maximum=7, step=1, label="Max NAR Levels", info="Limits how many steps to perform in the NAR pass.")
layout["inference_tts"]["inputs"]["beam-width"] = gr.Slider(value=0, minimum=0, maximum=32, step=1, label="Beam Width", info="Number of branches to search through for beam search sampling.")
layout["inference_tts"]["inputs"]["prefix-silence"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.0, step=0.05, label="Silence Prefix Duration", info="Amount of silence to prefix to the output response before beginning inference.")
with gr.Row():
layout["inference_tts"]["inputs"]["input-prompt-prefix"] = gr.Checkbox(label="Input Prompt as Prefix", info="Treats the input prompt clip as the prefix of the generated sequence.")
layout["inference_tts"]["inputs"]["dynamic-sampling"] = gr.Checkbox(label="Dynamic Temperature", info="Dynamically adjusts the temperature based on the highest confident predicted token per sampling step.")
layout["inference_tts"]["inputs"]["entropix-sampling"] = gr.Checkbox(label="Entropix Sampling", info="Dynamically samples based on entropy/varentropy values from the logits / attention scores.")
layout["inference_tts"]["inputs"]["refine-on-stop"] = gr.Checkbox(label="Refine on <stop>", info="Uses the last step's logits for the AR sequence instead.")
with gr.Row(visible=False):
layout["inference_tts"]["inputs"]["layer-skip"] = gr.Checkbox(label="Layer Skip", info="Performs self-speculative early exit 'sampling'")
layout["inference_tts"]["inputs"]["layer-skip-exit-layer"] = gr.Slider(value=11, minimum=0, maximum=11, step=1, label="Layer Skip Exit Layer", info="Maximum model layer to exit early from.")
layout["inference_tts"]["inputs"]["layer-skip-entropy-threshold"] = gr.Slider(value=0.1, minimum=0, maximum=1.0, step=0.01, label="Layer Skip Entropy Threshold", info="Entropy threshold for early-exit")
layout["inference_tts"]["inputs"]["layer-skip-varentropy-threshold"] = gr.Slider(value=0.1, minimum=0, maximum=1.0, step=0.01, label="Layer Skip Varentropy Threshold", info="Varentropy threshold for early-exit")
layout["inference_tts"]["buttons"]["inference"].click(
fn=do_inference_tts,
@@ -508,10 +509,8 @@ with ui:
with gr.Tab("Basic Settings"):
with gr.Row():
layout["inference_stt"]["inputs"]["ar-temperature"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR. (0 to greedy sample)")
with gr.Row():
layout["inference_stt"]["inputs"]["dynamic-sampling"] = gr.Checkbox(label="Dynamic Temperature", info="Dynamically adjusts the temperature based on the highest confident predicted token per sampling step.")
layout["inference_stt"]["inputs"]["language"] = gr.Dropdown(choices=get_languages(), label="Language", value="en")
with gr.Tab("Sampler Settings"):
layout["inference_stt"]["inputs"]["language"] = gr.Dropdown(choices=get_languages(), label="Language", value="en", info="Language of the input audio being transcribed.")
with gr.Tab("Sampler Settings", visible=False):
with gr.Row():
layout["inference_stt"]["inputs"]["top-p"] = gr.Slider(value=1.0, minimum=0.0, maximum=1.0, step=0.05, label="Top P", info=r"Limits the samples that are outside the top P% of probabilities.")
layout["inference_stt"]["inputs"]["top-k"] = gr.Slider(value=0, minimum=0, maximum=1024, step=1, label="Top K", info="Limits the samples to the top K of probabilities.")
@@ -522,6 +521,7 @@ with ui:
layout["inference_stt"]["inputs"]["repetition-penalty-decay"] = gr.Slider(value=0.0, minimum=-2.0, maximum=2.0, step=0.05, label="Repetition Penalty Length Decay", info="Modifies the reptition penalty based on how far back in time the token appeared in the sequence.")
layout["inference_stt"]["inputs"]["length-penalty"] = gr.Slider(value=0.0, minimum=-2.0, maximum=2.0, step=0.05, label="Length Penalty", info="(AR only) Modifies the probability of a stop token based on the current length of the sequence.")
with gr.Row():
layout["inference_stt"]["inputs"]["dynamic-sampling"] = gr.Checkbox(label="Dynamic Temperature", info="Dynamically adjusts the temperature based on the highest confident predicted token per sampling step.")
layout["inference_stt"]["inputs"]["mirostat-tau"] = gr.Slider(value=0.0, minimum=0.0, maximum=8.0, step=0.05, label="Mirostat τ (Tau)", info="The \"surprise\" value when performing mirostat sampling. 0 to disable.")
layout["inference_stt"]["inputs"]["mirostat-eta"] = gr.Slider(value=0.0, minimum=0.0, maximum=2.0, step=0.05, label="Mirostat η (Eta)", info="The \"learning rate\" during mirostat sampling applied to the maximum surprise.")
with gr.Row():
@@ -570,14 +570,15 @@ with ui:
if not USING_SPACES:
with gr.Tab("Settings"):
with gr.Row():
with gr.Column(scale=7):
with gr.Row():
layout["settings"]["inputs"]["models"] = gr.Dropdown(choices=get_model_paths(), value=args.yaml or args.model, label="Model")
layout["settings"]["inputs"]["device"] = gr.Dropdown(choices=get_devices(), value="cuda:0", label="Device")
layout["settings"]["inputs"]["dtype"] = gr.Dropdown(choices=get_dtypes(), value="auto", label="Precision")
layout["settings"]["inputs"]["attentions"] = gr.Dropdown(choices=get_attentions(), value="auto", label="Attentions")
with gr.Column(scale=1):
layout["settings"]["buttons"]["load"] = gr.Button(value="Load Model")
with gr.Column(scale=7):
with gr.Row():
layout["settings"]["inputs"]["models"] = gr.Dropdown(choices=get_model_paths(), value=args.yaml or args.model, label="Model", info="Model to load. Can load from a config YAML or the weights itself.")
layout["settings"]["inputs"]["device"] = gr.Dropdown(choices=get_devices(), value="cuda:0", label="Device", info="Device to load the weights onto.")
with gr.Row():
layout["settings"]["inputs"]["dtype"] = gr.Dropdown(choices=get_dtypes(), value="auto", label="Precision", info="Tensor type to load the model under.")
layout["settings"]["inputs"]["attentions"] = gr.Dropdown(choices=get_attentions(), value="auto", label="Attentions", info="Attention mechanism to utilize.")
layout["settings"]["buttons"]["load"].click(
fn=load_model,