shuffled the web UI options hidden behind cfg.experimental into their own tab, and exposed early-exit layer selection to inferencing (it works, naively; self-speculation still needs to be implemented)

This commit is contained in:
mrq 2024-11-01 21:30:06 -05:00
parent ef1c17430f
commit ec79230965
8 changed files with 68 additions and 16 deletions
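
For context: "early exit" here means stopping the decoder stack after a chosen layer and projecting that intermediate hidden state straight through the final norm and LM head. A minimal, self-contained sketch of the idea (illustrative only; the module, layer, and head names below are assumptions, not this repo's):

import torch.nn as nn

class TinyDecoder(nn.Module):
    def __init__(self, n_layers=12, d_model=256, n_tokens=1024):
        super().__init__()
        self.layers = nn.ModuleList(
            nn.TransformerEncoderLayer(d_model, nhead=4, batch_first=True)
            for _ in range(n_layers)
        )
        self.norm = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, n_tokens)

    def forward(self, x, exit_layer: int = -1):
        for l, layer in enumerate(self.layers):
            x = layer(x)
            # same check the commit adds: a non-negative exit_layer that is
            # <= the current layer index stops the stack early; the final
            # norm + head are reused on the intermediate hidden state
            if 0 <= exit_layer <= l:
                break
        return self.head(self.norm(x))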

View File

@ -47,6 +47,9 @@ def main():
parser.add_argument("--entropix-sampling", action="store_true")
parser.add_argument("--layer-skip", action="store_true")
parser.add_argument("--layer-skip-exit-layer", type=int, default=None)
parser.add_argument("--seed", type=int, default=None)
parser.add_argument("--device", type=str, default=None)
@ -81,6 +84,8 @@ def main():
mirostat_tau=args.mirostat_tau, mirostat_eta=args.mirostat_eta,
dry_multiplier=args.dry_multiplier, dry_base=args.dry_base, dry_allowed_length=args.dry_allowed_length,
entropix_sampling=args.entropix_sampling,
layer_skip=args.layer_skip,
layer_skip_exit_layer=args.layer_skip_exit_layer,
seed=args.seed,
)
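
With the flags above wired in, a run that exits the decoder after layer 6 might look like this (the positional arguments are illustrative; only --layer-skip and --layer-skip-exit-layer come from this commit):

python -m vall_e "The quick brown fox." ./reference.wav ./out.wav --layer-skip --layer-skip-exit-layer 6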

View File

@ -221,6 +221,9 @@ class TTS():
#
entropix_sampling=False,
#
layer_skip=False,
layer_skip_exit_layer=-1,
#
seed = None,
out_path=None,
@ -270,6 +273,8 @@ class TTS():
sampling_dry_base=dry_base,
sampling_dry_allowed_length=dry_allowed_length,
sampling_entropix=entropix_sampling,
sampling_layer_skip=layer_skip,
sampling_layer_skip_exit_layer=layer_skip_exit_layer,
disable_tqdm=not tqdm,
use_lora=use_lora,
@ -319,6 +324,8 @@ class TTS():
sampling_dry_base=dry_base,
sampling_dry_allowed_length=dry_allowed_length,
sampling_entropix=entropix_sampling,
sampling_layer_skip=layer_skip,
sampling_layer_skip_exit_layer=layer_skip_exit_layer,
disable_tqdm=not tqdm,
use_lora=use_lora,

View File

@ -38,7 +38,7 @@ class AR(Base):
tone_list: list[Tensor] | None = None,
len_list: list[Tensor] | None = None,
training: bool | None = None,
training: bool | int | None = None,
max_steps: int = 1000,
max_levels: int = 0,
@ -60,6 +60,9 @@ class AR(Base):
sampling_dry_multiplier=0.0,
sampling_dry_base=1.75,
sampling_dry_allowed_length=2,
sampling_entropix=False,
sampling_layer_skip: bool = False,
sampling_layer_skip_exit_layer: int = -1,
disable_tqdm=False,
use_lora=None,

View File

@ -66,6 +66,8 @@ class AR_NAR(Base):
sampling_dry_base=1.75,
sampling_dry_allowed_length=2,
sampling_entropix=False,
sampling_layer_skip: bool = False,
sampling_layer_skip_exit_layer: int = -1,
disable_tqdm=False,
use_lora=None,
@ -326,6 +328,9 @@ class AR_NAR(Base):
output = super().forward(
inputs=inputs,
state=state,
layer_skip_exit_layer=sampling_layer_skip_exit_layer,
output_attentions=sampling_entropix,
)
logits, state = output.logits, output.state

View File

@ -358,7 +358,7 @@ class LlamaModel_Adapted(LlamaModel):
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
early_exit_layer: Optional[int] = -1,
exit_layer: Optional[int] = -1,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@ -451,6 +451,9 @@ class LlamaModel_Adapted(LlamaModel):
if output_attentions:
all_self_attns += (layer_outputs[1],)
if 0 <= exit_layer <= l:
break
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
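
The commit message flags self-speculation as unimplemented; on top of this early-exit path it would draft tokens cheaply with the truncated stack and then verify them with the full stack, keeping the longest agreeing prefix. A purely illustrative sketch (nothing below is in this commit; model is a stand-in that returns per-position logits):

def self_speculate(model, ids, exit_layer, draft_len=4):
    # draft: greedy tokens from the shallow (early-exit) forward pass
    draft = []
    for _ in range(draft_len):
        logits = model(ids + draft, exit_layer=exit_layer)
        draft.append(int(logits[-1].argmax()))
    # verify: one full-depth pass over the drafted continuation; logits at
    # position p predict the token at position p + 1
    full = model(ids + draft, exit_layer=-1)
    keep = []
    for i, tok in enumerate(draft):
        if int(full[len(ids) + i - 1].argmax()) != tok:
            break
        keep.append(tok)
    return keep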

View File

@ -826,6 +826,9 @@ class Base(nn.Module):
position_ids = None,
state = None,
layer_skip_exit_layer = -1,
output_attentions = False,
output_hidden_states = False,
):
@ -848,9 +851,13 @@ class Base(nn.Module):
output_hidden_states=output_hidden_states,
return_dict=True,
)
if self.n_experts > 1 and self.training:
kwargs["output_router_logits"] = True
if self.layerskip and 0 <= layer_skip_exit_layer < self.n_layers:
kwargs["exit_layer"] = layer_skip_exit_layer
output = self.model(**kwargs)
x = output["last_hidden_state"]
@ -1436,8 +1443,10 @@ class Base(nn.Module):
quant_levels: int | list[int] | Tensor | None = None,
state: dict | list | None = None,
output_attentions = False,
output_hidden_states = False,
layer_skip_exit_layer: int = -1,
output_attentions: bool = False,
output_hidden_states: bool = False,
):
x_list = self.inputs_to_embeddings( inputs, quant_levels )
x, m = list_to_tensor(x_list)
@ -1477,6 +1486,7 @@ class Base(nn.Module):
position_ids=position_ids,
output_attentions = output_attentions,
output_hidden_states = output_hidden_states,
layer_skip_exit_layer = layer_skip_exit_layer,
)
logits = output.logits

View File

@ -35,11 +35,13 @@ class NAR(Base):
tone_list: list[Tensor] | None = None,
len_list: list[Tensor] | None = None,
training: bool | None = None,
training: bool | int | None = None,
max_steps: int = 1000,
max_levels: int = 0,
max_resp_context: int = -1,
input_prompt_prefix: bool = False,
prefix_silence: float = 1.0,
sampling_temperature: float = 1.0,
sampling_min_temperature: float = -1.0,
@ -52,8 +54,15 @@ class NAR(Base):
sampling_beam_width: int = 0,
sampling_mirostat_tau: float = 0.0,
sampling_mirostat_eta: float = 0.1,
sampling_dry_multiplier=0.0,
sampling_dry_base=1.75,
sampling_dry_allowed_length=2,
sampling_entropix=False,
sampling_layer_skip: bool = False,
sampling_layer_skip_exit_layer: int = -1,
disable_tqdm=False,
use_lora=None,
):
device = text_list[0].device
batch_size = len(text_list)

View File

@ -192,6 +192,8 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
parser.add_argument("--dry-base", type=float, default=kwargs["dry-base"])
parser.add_argument("--dry-allowed-length", type=int, default=kwargs["dry-allowed-length"])
parser.add_argument("--entropix-sampling", action="store_true")
parser.add_argument("--layer-skip", action="store_true")
parser.add_argument("--layer-skip-exit-layer", type=int, default=kwargs["layer-skip-exit-layer"])
args, unknown = parser.parse_known_args()
tmp = tempfile.NamedTemporaryFile(suffix='.wav')
@ -203,6 +205,9 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
if kwargs.pop("entropix-sampling", False):
args.entropix_sampling = True
if kwargs.pop("layer-skip", False):
args.layer_skip = True
tts = init_tts()
@ -236,7 +241,9 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
dry_multiplier=args.dry_multiplier,
dry_base=args.dry_base,
dry_allowed_length=args.dry_allowed_length,
entropix_sampling=args.entropix_sampling
entropix_sampling=args.entropix_sampling,
layer_skip=args.layer_skip,
layer_skip_exit_layer=args.layer_skip_exit_layer,
)
wav = wav.squeeze(0).cpu().numpy()
@ -372,19 +379,12 @@ with ui:
with gr.Tab("Basic Settings"):
with gr.Row():
layout["inference_tts"]["inputs"]["max-seconds"] = gr.Slider(value=12, minimum=1, maximum=32, step=0.1, label="Maximum Seconds", info="Limits how many steps to perform in the AR pass.")
if cfg.experimental:
layout["inference_tts"]["inputs"]["max-nar-levels"] = gr.Slider(value=7, minimum=0, maximum=7, step=1, label="Max NAR Levels", info="Limits how many steps to perform in the NAR pass.")
layout["inference_tts"]["inputs"]["input-prompt-length"] = gr.Slider(value=5.0, minimum=0.0, maximum=12.0, step=0.05, label="Input Prompt Repeat/Trim Length", info="Repeats and trims the input prompt down to X seconds. Set 0 to disable.")
with gr.Row():
layout["inference_tts"]["inputs"]["ar-temp"] = gr.Slider(value=0.5, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR. (0 to greedy* sample)")
layout["inference_tts"]["inputs"]["nar-temp"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (NAR)", info="Modifies the randomness from the samples in the NAR. (0 to greedy sample)")
with gr.Row():
if cfg.experimental:
layout["inference_tts"]["inputs"]["input-prompt-prefix"] = gr.Checkbox(label="Input Prompt as Prefix", info="Treats the input prompt clip as the prefix of the generated sequence.")
layout["inference_tts"]["inputs"]["prefix-silence"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.0, step=0.05, label="Silence Prefix Duration", info="Amount of silence to prefix to the output response before beginning inference.")
layout["inference_tts"]["inputs"]["dynamic-sampling"] = gr.Checkbox(label="Dynamic Temperature", info="Dynamically adjusts the temperature based on the highest confident predicted token per sampling step.")
if cfg.experimental:
layout["inference_tts"]["inputs"]["entropix-sampling"] = gr.Checkbox(label="Entropix Sampling", info="Dynamically samples based on entropy/varentropy values from the logits / attention scores.")
layout["inference_tts"]["inputs"]["language"] = gr.Dropdown(choices=get_languages(), label="Language", value="en")
with gr.Tab("Sampler Settings"):
with gr.Row():
@ -403,6 +403,18 @@ with ui:
layout["inference_tts"]["inputs"]["dry-multiplier"] = gr.Slider(value=0.0, minimum=0.0, maximum=8.0, step=0.05, label="DRY Multiplier", info="The multiplying factor for the DRY score penalty (0 to disable DRY sampling).")
layout["inference_tts"]["inputs"]["dry-base"] = gr.Slider(value=1.75, minimum=0.0, maximum=8.0, step=0.05, label="DRY Base", info="The base of the exponent in the DRY score penalty")
layout["inference_tts"]["inputs"]["dry-allowed-length"] = gr.Slider(value=2, minimum=0, maximum=75, step=1, label="Allowed Length", info="The maximimum length a token can be to perform DRY penalty with.")
if cfg.experimental:
with gr.Tab("Experimental Settings"):
with gr.Row():
layout["inference_tts"]["inputs"]["max-nar-levels"] = gr.Slider(value=7, minimum=0, maximum=7, step=1, label="Max NAR Levels", info="Limits how many steps to perform in the NAR pass.")
layout["inference_tts"]["inputs"]["input-prompt-prefix"] = gr.Checkbox(label="Input Prompt as Prefix", info="Treats the input prompt clip as the prefix of the generated sequence.")
with gr.Row():
layout["inference_tts"]["inputs"]["prefix-silence"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.0, step=0.05, label="Silence Prefix Duration", info="Amount of silence to prefix to the output response before beginning inference.")
with gr.Row():
layout["inference_tts"]["inputs"]["entropix-sampling"] = gr.Checkbox(label="Entropix Sampling", info="Dynamically samples based on entropy/varentropy values from the logits / attention scores.")
with gr.Row():
layout["inference_tts"]["inputs"]["layer-skip-exit-layer"] = gr.Slider(value=11, minimum=0, maximum=11, step=1, label="Layer Skip Exit Layer", info="Model layer to exit early from.")
layout["inference_tts"]["buttons"]["inference"].click(
fn=do_inference_tts,
@ -425,8 +437,6 @@ with ui:
layout["inference_stt"]["inputs"]["ar-temp"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR. (0 to greedy sample)")
with gr.Row():
layout["inference_stt"]["inputs"]["dynamic-sampling"] = gr.Checkbox(label="Dynamic Temperature", info="Dynamically adjusts the temperature based on the highest confident predicted token per sampling step.")
if cfg.experimental:
layout["inference_stt"]["inputs"]["entropix-sampling"] = gr.Checkbox(label="Entropix Sampling", info="Dynamically samples based on entropy/varentropy values from the logits / attention scores.")
layout["inference_stt"]["inputs"]["language"] = gr.Dropdown(choices=get_languages(), label="Language", value="en")
with gr.Tab("Sampler Settings"):
with gr.Row():