From 1d460b9fe3c7195db55ad084184446ac6928d5ee Mon Sep 17 00:00:00 2001
From: mrq
Date: Sun, 8 Dec 2024 14:52:47 -0600
Subject: [PATCH] logic fixes, I feel like output is better? (also NAR can have
 a temperature, I imagine it couldn't because it was having a causal mask
 passed to it for the longest time before I caught it a month ago)

---
 vall_e/data.py          |   2 +-
 vall_e/demo.py          |   3 +-
 vall_e/inference.py     |   4 --
 vall_e/models/ar_nar.py | 102 +++++++++++++---------------------------
 vall_e/webui.py         |  69 +++++++++++++--------------
 5 files changed, 70 insertions(+), 110 deletions(-)

diff --git a/vall_e/data.py b/vall_e/data.py
index b89a385..2995723 100755
--- a/vall_e/data.py
+++ b/vall_e/data.py
@@ -61,7 +61,7 @@ def sentence_split( s, split_by="sentences", quote_placeholder="" ):
 	# nltk does not split quotations all that nicely, so we coerce them into placeholders, then replace afterwards
 	s = s.replace('"', quote_placeholder)
 	sentences = nltk.sent_tokenize(s)
-	return [ sentence.replace(quote_placeholder, '"') for sentence in sentences ]
+	return [ sentence.replace(quote_placeholder, '"') for sentence in sentences if sentence ]
 
 @cache
 def get_random_prompts( validation=False, min_length=0, tokenized=False ):
diff --git a/vall_e/demo.py b/vall_e/demo.py
index 01e028d..659238c 100644
--- a/vall_e/demo.py
+++ b/vall_e/demo.py
@@ -123,6 +123,7 @@ def main():
 	parser.add_argument("--device", type=str, default=None)
 	parser.add_argument("--amp", action="store_true")
 	parser.add_argument("--dtype", type=str, default=None)
+	parser.add_argument("--attention", type=str, default="auto")
 
 	parser.add_argument("--random-prompts", action="store_true")
 	parser.add_argument("--lora", action="store_true")
@@ -136,7 +137,7 @@ def main():
 	elif args.model:
 		config = args.model
 
-	tts = TTS( config=config, lora=args.lora, device=args.device, dtype=args.dtype, amp=args.amp )
+	tts = TTS( config=config, lora=args.lora, device=args.device, dtype=args.dtype, amp=args.amp, attention=args.attention )
 
 	if not args.demo_dir:
 		args.demo_dir = Path("./data/demo/")
diff --git a/vall_e/inference.py b/vall_e/inference.py
index ccea91c..5b92213 100755
--- a/vall_e/inference.py
+++ b/vall_e/inference.py
@@ -300,8 +300,6 @@ class TTS():
 		if model_len is not None:
 			# extra kwargs
 			duration_padding = sampling_kwargs.pop("duration_padding", 1.05)
-			nar_len_prefix_length = sampling_kwargs.pop("nar_len_prefix_length", 0)
-
 			len_list = model_len( **input_kwargs, task_list=["len"]*batch_size, **{"max_duration": 5} ) # "max_duration" is max tokens
 
 			# add an additional X seconds
@@ -443,8 +441,6 @@ class TTS():
 		if model_len is not None:
 			# extra kwargs
 			duration_padding = sampling_kwargs.pop("duration_padding", 1.05)
-			nar_len_prefix_length = sampling_kwargs.pop("nar_len_prefix_length", 0)
-
 			len_list = model_len( **input_kwargs, task_list=["len"], **{"max_duration": 5} ) # "max_duration" is max tokens
 
 			# add an additional X seconds
diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index 1067a14..965a19a 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -223,31 +223,19 @@ class AR_NAR(Base):
 		device = text_list[0].device
 		batch_size = len(text_list)
 
-		# special "scheduling" to inference RVQ-level 0
 		level = 0
 		if cfg.lora is not None:
 			enable_lora( self, cfg.lora.active_level( level ) if use_lora is None else use_lora )
 
-		# to-do: check if gumbel sampling works / helps
 		"""
-		def log(x, eps = 1e-20):
-			return torch.log(x.clamp(min = eps))
-
-		def gumbel_sample(x, temperature = 1., dim = -1):
-			return ((x / max(temperature, 1e-10)) + -log(-log(torch.zeros_like(x).uniform_(0, 1)))).argmax(dim = dim)
-		"""
-
 		def log(t, eps=1e-10):
 			return torch.log(t + eps)
-
-
 		def gumbel_noise(t):
 			noise = torch.zeros_like(t).uniform_(0, 1)
 			return -log(-log(noise))
-
-
 		def gumbel_sample(t, temperature=1.0, dim=-1):
 			return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim=dim)
+		"""
 
 		# convert (N)AR specific args
 		sampling_kwargs = convert_kwargs( sampling_kwargs, "ar_" )
@@ -261,18 +249,18 @@ class AR_NAR(Base):
 
 		# greedy sampling is very, very much preferred, but using greedy logit scores later helps enough
 		temperature = sampling_kwargs.pop("temperature", 0.0)
+		minimum_cfg_strength = sampling_kwargs.get("minimum_cfg_strength", 2.5)
 		# this really helps keep audio coherent so far
-		cfg_strength = sampling_kwargs.get("cfg_strength", 2.0)
+		cfg_strength = sampling_kwargs.get("cfg_strength", minimum_cfg_strength)
 		cfg_rescale = sampling_kwargs.pop("cfg_rescale", 0.75)
 		start_noise = sampling_kwargs.get("denoise_start", 0.0)
 		end_noise = sampling_kwargs.get("denoise_end", 1.0)
-		remasking = sampling_kwargs.get("remasking", True)
+		remasking = sampling_kwargs.get("remasking", False)
 		max_steps = math.floor(max_steps * (end_noise - start_noise))
 
 		len_list = [ clamp(l, min_length, max_length) for l in len_list ]
 
 		# force set CFG because too low / no CFG causes issues
-		minimum_cfg_strength = sampling_kwargs.get("minimum_cfg_strength", 3.0)
 		original_cfg_strength = cfg_strength
 		cfg_strength = max( cfg_strength, minimum_cfg_strength )
 
@@ -320,17 +308,19 @@ class AR_NAR(Base):
 			time_list = [ timestep for _ in range(batch_size) ]
 
 			sampling_temperature = temperature * annealing if annealed_sampling else temperature
-			sampling_cfg = cfg_strength * timestep if annealed_sampling else temperature
+			sampling_cfg = cfg_strength * timestep if annealed_sampling else cfg_strength
 
 			# avoid useless CFG sampling
+			"""
 			if sampling_cfg < minimum_cfg_strength * 0.5:
 				sampling_cfg = 0
+			"""
 
 			if prefix_context is not None:
 				input_resps_list = [ torch.concat( [ prefix, resps ] ) for prefix, resps in zip( prefix_resps_list, resps_list ) ]
 				# originally requested no CFG, safe to ignore if we have a prefix
-				if original_cfg_strength == 0:
-					sampling_cfg = 0
+				if original_cfg_strength < minimum_cfg_strength:
+					sampling_cfg = original_cfg_strength * timestep if annealed_sampling else original_cfg_strength
 			else:
 				input_resps_list = resps_list
 
@@ -347,7 +337,6 @@ class AR_NAR(Base):
 			output = super().forward(
 				inputs=inputs,
 				quant_levels=quant_levels,
-				#layer_skip_variables=sampling_layer_skip_variables,
 			)
 			logits = output.logits
 
@@ -365,7 +354,6 @@ class AR_NAR(Base):
 				null_output = super().forward(
 					inputs=null_inputs,
 					quant_levels=quant_levels,
-					#layer_skip_variables=sampling_layer_skip_variables,
 				)
 
 				logits = cfg_logits( logits=output.logits, null=null_output.logits, strength=cfg_strength, rescale=cfg_rescale, lens=[ l for l in len_list ] )
@@ -420,49 +408,6 @@ class AR_NAR(Base):
 		use_lora=None,
 		**sampling_kwargs,
 	):
-		# deduce batch_size
-		if text_list is not None:
-			default_task = "tts"
-			device = text_list[0].device
-			batch_size = len(text_list)
-		else:
-			default_task = "stt"
-			device = resps_list[0].device
-			batch_size = len(resps_list)
-
-		# convert NAR specific args
-		sampling_kwargs = convert_kwargs( sampling_kwargs, "nar_" )
-
-		max_levels = sampling_kwargs.get("max_levels", 0)
-		cfg_strength = sampling_kwargs.get("cfg_strength", 0.0)
-		cfg_rescale = sampling_kwargs.pop("cfg_rescale", 0.7)
-
-		if max_levels == 0:
-			max_levels = self.n_max_levels - 1
-
-		# prefixed context provided
-		"""
-		prefix_context = sampling_kwargs.get("prefix_context", None)
-		if prefix_context is not None:
-			prefix_text, prefix_resps, _ = prefix_context
-			# to-do: check if we actually need to drop the middle ""
-			text_list = [ torch.concat([prefix[:-1], text[1:]]) for prefix, text in zip( prefix_text, text_list ) ]
-			# feeding this into the NAR-len should automatically handle things
-			resps_list = [ resps for resps in prefix_resps ]
-		"""
-
-		"""
-		sampling_layer_skip_variables = {} if sampling_layer_skip else None
-
-		if sampling_layer_skip:
-			if sampling_layer_skip_entropy_threshold >= 0:
-				sampling_layer_skip_variables["entropy_threshold"] = sampling_layer_skip_entropy_threshold
-			if sampling_layer_skip_varentropy_threshold >= 0:
-				sampling_layer_skip_variables["varentropy_threshold"] = sampling_layer_skip_varentropy_threshold
-			if sampling_layer_skip_exit_layer >= 0:
-				sampling_layer_skip_variables["max_layer"] = sampling_layer_skip_exit_layer
-		"""
-
 		# inference NAR level 0
 		if len_list is not None:
 			resps_list = self.forward_nar_masked(
@@ -475,6 +420,26 @@ class AR_NAR(Base):
 				len_list=len_list,
 				**sampling_kwargs,
 			)
+
+		# deduce batch_size
+		if text_list is not None:
+			default_task = "tts"
+			device = text_list[0].device
+			batch_size = len(text_list)
+		else:
+			default_task = "stt"
+			device = resps_list[0].device
+			batch_size = len(resps_list)
+
+		# convert NAR specific args
+		sampling_kwargs = convert_kwargs( sampling_kwargs, "nar_" )
+
+		max_levels = sampling_kwargs.get("max_levels", 0)
+		cfg_strength = sampling_kwargs.get("cfg_strength", 0.0)
+		cfg_rescale = sampling_kwargs.pop("cfg_rescale", 0.7)
+
+		if max_levels == 0:
+			max_levels = self.n_max_levels - 1
 
 		# expand if given a raw 1D tensor
 		for i, resp in enumerate(resps_list):
@@ -489,14 +454,14 @@ class AR_NAR(Base):
 		iterator = trange( max_levels, desc="NAR", disable=disable_tqdm )
 		for n in iterator:
 			level = prev_list[0].shape[-1]
-			if level >= max_levels + 1: # min(max_levels + 1, self.n_resp_levels): # commented out to experiment with exceeding trained levels
+			if level >= max_levels + 1:
 				iterator.close()
 				break
 
 			if cfg.lora is not None:
 				enable_lora( self, cfg.lora.active_level( level ) if use_lora is None else use_lora )
 
-			quant_levels = [ level for _ in range(batch_size) ] # torch.full((len(text_list),), level)
+			quant_levels = [ level for _ in range(batch_size) ]
 
 			inputs = self.inputs(
 				text_list=text_list,
@@ -510,7 +475,6 @@ class AR_NAR(Base):
 			output = super().forward(
 				inputs=inputs,
 				quant_levels=quant_levels,
-				#layer_skip_variables=sampling_layer_skip_variables,
 			)
 			logits, state = output.logits, output.state
 
@@ -526,7 +490,6 @@ class AR_NAR(Base):
 				null_output = super().forward(
 					inputs=null_inputs,
 					quant_levels=quant_levels,
-					#layer_skip_variables=sampling_layer_skip_variables,
 				)
 
 				logits = cfg_logits( logits=output.logits, null=null_output.logits, strength=cfg_strength, rescale=cfg_rescale, lens=[ resp.shape[0] for resp in resps_list ] )
@@ -535,8 +498,7 @@ class AR_NAR(Base):
 				logits=logits,
 				prev_list=prev_list,
 				quant_levels=quant_levels,
-				#temperature=0.0,
-				**(sampling_kwargs | {"temperature": 0.0}),
+				**(sampling_kwargs),
 			)
 
 			resps_list = sampled.ids
diff --git a/vall_e/webui.py b/vall_e/webui.py
index dcf587e..516852a 100644
--- a/vall_e/webui.py
+++ b/vall_e/webui.py
@@ -440,31 +440,43 @@ with ui:
 			with gr.Column(scale=7):
 				with gr.Tab("Basic Settings"):
 					with gr.Row():
-						layout["inference_tts"]["inputs"]["max-duration"] = gr.Slider(value=12, minimum=1, maximum=32, step=0.1, label="Maximum Duration", info="Limits how many steps to perform in the AR pass.")
-						layout["inference_tts"]["inputs"]["max-steps"] = gr.Slider(value=50, minimum=1, maximum=200, step=1, label="Max Steps (NAR-len)", info="Limits how many steps to perform in the NAR-len (demask) pass.")
-						layout["inference_tts"]["inputs"]["input-prompt-length"] = gr.Slider(value=5.0, minimum=0.0, maximum=12.0, step=0.05, label="Input Prompt Repeat/Trim Length", info="Repeats and trims the input prompt down to X seconds. Set 0 to disable.")
+						layout["inference_tts"]["inputs"]["max-steps"] = gr.Slider(value=50, minimum=1, maximum=200, step=1, label="Max Steps", info="Limits how many steps to perform in the NAR-len (demask) pass.")
+						layout["inference_tts"]["inputs"]["max-duration"] = gr.Slider(value=12, minimum=1, maximum=32, step=0.1, label="Maximum Duration", info="Limits how long an utterance can be.")
+						layout["inference_tts"]["inputs"]["input-prompt-length"] = gr.Slider(value=0.0, minimum=0.0, maximum=12.0, step=0.5, label="Input Prompt Repeat/Trim Length", info="Repeats/trims the input prompt down to X seconds (0 to disable).")
 					with gr.Row():
-						layout["inference_tts"]["inputs"]["ar-temperature"] = gr.Slider(value=1.0, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (AR/NAR-len)", info="Modifies the randomness from the samples in the AR/NAR-len. (0 to greedy* sample)")
-						layout["inference_tts"]["inputs"]["nar-temperature"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (NAR)", info="Modifies the randomness from the samples in the NAR. (0 to greedy sample)")
-						layout["inference_tts"]["inputs"]["modality"] = gr.Dropdown(value="Auto", choices=["Auto", "AR+NAR", "NAR-len"], label="Modality", info="Whether to inference with the AR+NAR or through the NAR-len.")
-					with gr.Row():
-						layout["inference_tts"]["inputs"]["cfg-strength"] = gr.Slider(value=1.0, minimum=0.0, maximum=14.0, step=0.05, label="CFG Strength", info="Classifier Free Guidance scale (AR needs 1, NAR-len needs 3).")
-						layout["inference_tts"]["inputs"]["cfg-rescale"] = gr.Slider(value=0.75, minimum=0.0, maximum=1.0, step=0.05, label="CFG Rescale (Phi)", info="Factor when rescaling for Classifier Free Guidance (0 to disable).")
-						layout["inference_tts"]["inputs"]["language"] = gr.Dropdown(choices=get_languages(), label="Language (Output)", value="auto", info="Target language/accent to output.")
 						layout["inference_tts"]["inputs"]["text-language"] = gr.Dropdown(choices=get_languages(), label="Language (Text)", value="auto", info="Language the input text is in.")
+						layout["inference_tts"]["inputs"]["language"] = gr.Dropdown(choices=get_languages(), label="Language (Output)", value="auto", info="Target language/accent to output.")
 					with gr.Row():
-						layout["inference_tts"]["inputs"]["split-text-by"] = gr.Dropdown(choices=["sentences", "lines"], label="Text Delimiter", info="Splits the text into pieces.", value="sentences")
+						layout["inference_tts"]["inputs"]["split-text-by"] = gr.Dropdown(choices=["sentences", "lines"], label="Text Delimiter", info="How to split the text into utterances.", value="sentences")
 						layout["inference_tts"]["inputs"]["context-history"] = gr.Slider(value=0, minimum=0, maximum=4, step=1, label="(Rolling) Context History", info="How many prior lines to serve as the context/prefix (0 to disable).")
 				with gr.Tab("Sampler Settings"):
 					with gr.Row():
+						layout["inference_tts"]["inputs"]["ar-temperature"] = gr.Slider(value=1.0, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (AR/NAR-len)", info="Adjusts the probabilities in the AR/NAR-len. (0 to greedy* sample)")
+						layout["inference_tts"]["inputs"]["nar-temperature"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (NAR)", info="Adjusts the probabilities in the NAR. (0 to greedy sample)")
+						layout["inference_tts"]["inputs"]["modality"] = gr.Dropdown(value="Auto", choices=["Auto", "AR+NAR", "NAR-len"], label="Modality", info="Whether to inference with the AR+NAR or through the NAR-len.")
+					with gr.Row():
+						layout["inference_tts"]["inputs"]["cfg-strength"] = gr.Slider(value=1.0, minimum=0.0, maximum=14.0, step=0.5, label="CFG Strength", info="Classifier Free Guidance scale (AR needs 1, NAR-len needs 3).")
+						layout["inference_tts"]["inputs"]["cfg-rescale"] = gr.Slider(value=0.75, minimum=0.0, maximum=1.0, step=0.05, label="CFG Rescale (Phi)", info="Factor when rescaling for Classifier Free Guidance (0 to disable).")
+					with gr.Row():
+						layout["inference_tts"]["inputs"]["min-p"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.0, step=0.05, label="Min P", info="Filter out logits lower than this value.")
 						layout["inference_tts"]["inputs"]["top-p"] = gr.Slider(value=1.0, minimum=0.0, maximum=1.0, step=0.05, label="Top P", info=r"Limits the samples that are outside the top P% of probabilities.")
 						layout["inference_tts"]["inputs"]["top-k"] = gr.Slider(value=0, minimum=0, maximum=1024, step=1, label="Top K", info="Limits the samples to the top K of probabilities.")
-						layout["inference_tts"]["inputs"]["top-no"] = gr.Slider(value=0, minimum=0, maximum=2, step=0.05, label="Top-nσ", info="Performs top-nσ logits processing.")
-						layout["inference_tts"]["inputs"]["min-p"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.0, step=0.05, label="Min P", info="Filter out logits lower than this value.")
+						layout["inference_tts"]["inputs"]["top-no"] = gr.Slider(value=0, minimum=0, maximum=2, step=0.5, label="Top-nσ", info="Performs top-nσ logits processing.")
 					with gr.Row():
 						layout["inference_tts"]["inputs"]["repetition-penalty"] = gr.Slider(value=1.0, minimum=0.0, maximum=5.0, step=0.05, label="Repetition Penalty", info="Incurs a penalty to tokens based on how often they appear in a sequence.")
 						layout["inference_tts"]["inputs"]["repetition-penalty-decay"] = gr.Slider(value=0.0, minimum=-2.0, maximum=2.0, step=0.05, label="Repetition Penalty Length Decay", info="Modifies the reptition penalty based on how far back in time the token appeared in the sequence.")
 						layout["inference_tts"]["inputs"]["length-penalty"] = gr.Slider(value=0.0, minimum=-2.0, maximum=2.0, step=0.05, label="Length Penalty", info="(AR only) Modifies the probability of a stop token based on the current length of the sequence.")
+				# These settings are pretty much not supported anyways
+				with gr.Tab("Experimental Settings", visible=cfg.experimental):
+					with gr.Row():
+						layout["inference_tts"]["inputs"]["max-levels"] = gr.Slider(value=7, minimum=0, maximum=7, step=1, label="Max NAR Levels", info="Limits how many steps to perform in the NAR pass.")
+						layout["inference_tts"]["inputs"]["beam-width"] = gr.Slider(value=0, minimum=0, maximum=32, step=1, label="Beam Width", info="Number of branches to search through for beam search sampling.")
+						layout["inference_tts"]["inputs"]["prefix-silence"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.0, step=0.5, label="Silence Prefix Duration", info="Amount of silence to prefix to the output response before beginning inference.")
+					with gr.Row():
+						layout["inference_tts"]["inputs"]["input-prompt-prefix"] = gr.Checkbox(label="Input Prompt as Prefix", info="Treats the input prompt clip as the prefix of the generated sequence.")
+						layout["inference_tts"]["inputs"]["dynamic-sampling"] = gr.Checkbox(label="Dynamic Temperature", info="Dynamically adjusts the temperature based on the highest confident predicted token per sampling step.")
+						layout["inference_tts"]["inputs"]["entropix-sampling"] = gr.Checkbox(label="Entropix Sampling", info="Dynamically samples based on entropy/varentropy values from the logits / attention scores.")
+						layout["inference_tts"]["inputs"]["refine-on-stop"] = gr.Checkbox(label="Refine on ", info="Uses the last step's logits for the AR sequence instead.")
 					with gr.Row():
 						layout["inference_tts"]["inputs"]["mirostat-tau"] = gr.Slider(value=0.0, minimum=0.0, maximum=8.0, step=0.05, label="Mirostat τ (Tau)", info="The \"surprise\" value when performing mirostat sampling. 0 to disable.")
 						layout["inference_tts"]["inputs"]["mirostat-eta"] = gr.Slider(value=0.0, minimum=0.0, maximum=2.0, step=0.05, label="Mirostat η (Eta)", info="The \"learning rate\" during mirostat sampling applied to the maximum surprise.")
@@ -472,22 +484,11 @@ with ui:
 						layout["inference_tts"]["inputs"]["dry-multiplier"] = gr.Slider(value=0.0, minimum=0.0, maximum=8.0, step=0.05, label="DRY Multiplier", info="The multiplying factor for the DRY score penalty (0 to disable DRY sampling).")
 						layout["inference_tts"]["inputs"]["dry-base"] = gr.Slider(value=1.75, minimum=0.0, maximum=8.0, step=0.05, label="DRY Base", info="The base of the exponent in the DRY score penalty")
 						layout["inference_tts"]["inputs"]["dry-allowed-length"] = gr.Slider(value=2, minimum=0, maximum=75, step=1, label="Allowed Length", info="The maximimum length a token can be to perform DRY penalty with.")
-				with gr.Tab("Experimental Settings", visible=cfg.experimental):
 					with gr.Row():
-						layout["inference_tts"]["inputs"]["max-levels"] = gr.Slider(value=7, minimum=0, maximum=7, step=1, label="Max NAR Levels", info="Limits how many steps to perform in the NAR pass.")
-						layout["inference_tts"]["inputs"]["beam-width"] = gr.Slider(value=0, minimum=0, maximum=32, step=1, label="Beam Width", info="Number of branches to search through for beam search sampling.")
-						layout["inference_tts"]["inputs"]["prefix-silence"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.0, step=0.05, label="Silence Prefix Duration", info="Amount of silence to prefix to the output response before beginning inference.")
-					with gr.Row():
-						layout["inference_tts"]["inputs"]["input-prompt-prefix"] = gr.Checkbox(label="Input Prompt as Prefix", info="Treats the input prompt clip as the prefix of the generated sequence.")
-						layout["inference_tts"]["inputs"]["dynamic-sampling"] = gr.Checkbox(label="Dynamic Temperature", info="Dynamically adjusts the temperature based on the highest confident predicted token per sampling step.")
-						layout["inference_tts"]["inputs"]["entropix-sampling"] = gr.Checkbox(label="Entropix Sampling", info="Dynamically samples based on entropy/varentropy values from the logits / attention scores.")
-						layout["inference_tts"]["inputs"]["refine-on-stop"] = gr.Checkbox(label="Refine on ", info="Uses the last step's logits for the AR sequence instead.")
-					with gr.Row(visible=False):
 						layout["inference_tts"]["inputs"]["layer-skip"] = gr.Checkbox(label="Layer Skip", info="Performs self-speculative early exit 'sampling'")
 						layout["inference_tts"]["inputs"]["layer-skip-exit-layer"] = gr.Slider(value=11, minimum=0, maximum=11, step=1, label="Layer Skip Exit Layer", info="Maximum model layer to exit early from.")
 						layout["inference_tts"]["inputs"]["layer-skip-entropy-threshold"] = gr.Slider(value=0.1, minimum=0, maximum=1.0, step=0.01, label="Layer Skip Entropy Threshold", info="Entropy threshold for early-exit")
 						layout["inference_tts"]["inputs"]["layer-skip-varentropy-threshold"] = gr.Slider(value=0.1, minimum=0, maximum=1.0, step=0.01, label="Layer Skip Varentropy Threshold", info="Varentropy threshold for early-exit")
-
 
 		layout["inference_tts"]["buttons"]["inference"].click(
 			fn=do_inference_tts,
@@ -508,10 +509,8 @@ with ui:
 				with gr.Tab("Basic Settings"):
 					with gr.Row():
 						layout["inference_stt"]["inputs"]["ar-temperature"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR. (0 to greedy sample)")
-					with gr.Row():
-						layout["inference_stt"]["inputs"]["dynamic-sampling"] = gr.Checkbox(label="Dynamic Temperature", info="Dynamically adjusts the temperature based on the highest confident predicted token per sampling step.")
-						layout["inference_stt"]["inputs"]["language"] = gr.Dropdown(choices=get_languages(), label="Language", value="en")
-				with gr.Tab("Sampler Settings"):
+						layout["inference_stt"]["inputs"]["language"] = gr.Dropdown(choices=get_languages(), label="Language", value="en", info="Language of the input audio being transcribed.")
+				with gr.Tab("Sampler Settings", visible=False):
 					with gr.Row():
 						layout["inference_stt"]["inputs"]["top-p"] = gr.Slider(value=1.0, minimum=0.0, maximum=1.0, step=0.05, label="Top P", info=r"Limits the samples that are outside the top P% of probabilities.")
 						layout["inference_stt"]["inputs"]["top-k"] = gr.Slider(value=0, minimum=0, maximum=1024, step=1, label="Top K", info="Limits the samples to the top K of probabilities.")
@@ -522,6 +521,7 @@ with ui:
 						layout["inference_stt"]["inputs"]["repetition-penalty-decay"] = gr.Slider(value=0.0, minimum=-2.0, maximum=2.0, step=0.05, label="Repetition Penalty Length Decay", info="Modifies the reptition penalty based on how far back in time the token appeared in the sequence.")
 						layout["inference_stt"]["inputs"]["length-penalty"] = gr.Slider(value=0.0, minimum=-2.0, maximum=2.0, step=0.05, label="Length Penalty", info="(AR only) Modifies the probability of a stop token based on the current length of the sequence.")
 					with gr.Row():
+						layout["inference_stt"]["inputs"]["dynamic-sampling"] = gr.Checkbox(label="Dynamic Temperature", info="Dynamically adjusts the temperature based on the highest confident predicted token per sampling step.")
 						layout["inference_stt"]["inputs"]["mirostat-tau"] = gr.Slider(value=0.0, minimum=0.0, maximum=8.0, step=0.05, label="Mirostat τ (Tau)", info="The \"surprise\" value when performing mirostat sampling. 0 to disable.")
 						layout["inference_stt"]["inputs"]["mirostat-eta"] = gr.Slider(value=0.0, minimum=0.0, maximum=2.0, step=0.05, label="Mirostat η (Eta)", info="The \"learning rate\" during mirostat sampling applied to the maximum surprise.")
 					with gr.Row():
@@ -570,14 +570,15 @@ with ui:
 	if not USING_SPACES:
 		with gr.Tab("Settings"):
 			with gr.Row():
-				with gr.Column(scale=7):
-					with gr.Row():
-						layout["settings"]["inputs"]["models"] = gr.Dropdown(choices=get_model_paths(), value=args.yaml or args.model, label="Model")
-						layout["settings"]["inputs"]["device"] = gr.Dropdown(choices=get_devices(), value="cuda:0", label="Device")
-						layout["settings"]["inputs"]["dtype"] = gr.Dropdown(choices=get_dtypes(), value="auto", label="Precision")
-						layout["settings"]["inputs"]["attentions"] = gr.Dropdown(choices=get_attentions(), value="auto", label="Attentions")
 				with gr.Column(scale=1):
 					layout["settings"]["buttons"]["load"] = gr.Button(value="Load Model")
+				with gr.Column(scale=7):
+					with gr.Row():
+						layout["settings"]["inputs"]["models"] = gr.Dropdown(choices=get_model_paths(), value=args.yaml or args.model, label="Model", info="Model to load. Can load from a config YAML or the weights itself.")
+						layout["settings"]["inputs"]["device"] = gr.Dropdown(choices=get_devices(), value="cuda:0", label="Device", info="Device to load the weights onto.")
+					with gr.Row():
+						layout["settings"]["inputs"]["dtype"] = gr.Dropdown(choices=get_dtypes(), value="auto", label="Precision", info="Tensor type to load the model under.")
+						layout["settings"]["inputs"]["attentions"] = gr.Dropdown(choices=get_attentions(), value="auto", label="Attentions", info="Attention mechanism to utilize.")
 
 		layout["settings"]["buttons"]["load"].click(
 			fn=load_model,