more adjustments (adjustments of early-exit entropy/varentropy thresholds, default rep pen being 1.5, experimental refine-on-stop, etc.)

This commit is contained in:
mrq 2024-11-03 18:31:28 -06:00
parent aee08b7307
commit d229725c76
8 changed files with 95 additions and 20 deletions

View File

@ -240,8 +240,13 @@ And some experimental sampling flags you can use too (your mileage will ***defin
* `--dry-multiplier`: (AR only) performs DRY sampling, the scalar factor.
* `--dry-base`: (AR only) for DRY sampling, the base of the exponent factor.
* `--dry-allowed-length`: (AR only) for DRY sampling, the window to perform DRY sampling within.
* `--layer-skip` (AR only) enables early-exit layer skipping if the model is confident enough (for compatible models)
* `--layer-skip-exit-layer`: (AR only) maximum layer to use (for compatbiel models)
* `--layer-skip` enables early-exit layer skipping if the model is confident enough (for compatible models)
* `--layer-skip-exit-layer`: maximum layer to use
* `--layer-skip-entropy-threshold`: the maximum the logits' entropy (confidence) needs to be before exiting early
* `--layer-skip-varentropy-threshold`: the maximum the logits' varentropy (confidence spread) needs to be before exiting early
* `--refine-on-stop`: (AR only) uses the last steps' logits for the entire final output sequence, rather than the step-by-step iterative sequence.
+ This needs experimenting with to see if there's any downside.
+ to-do: compare the probability scores with the original output sequence, and pick the best one.
### Speech-to-Text

View File

@ -33,7 +33,7 @@ def main():
parser.add_argument("--top-p", type=float, default=1.0)
parser.add_argument("--top-k", type=int, default=0)
parser.add_argument("--min-p", type=float, default=0.0)
parser.add_argument("--repetition-penalty", type=float, default=1.125)
parser.add_argument("--repetition-penalty", type=float, default=1.5)
parser.add_argument("--repetition-penalty-decay", type=float, default=0.0)
parser.add_argument("--length-penalty", type=float, default=0.0)
parser.add_argument("--beam-width", type=int, default=0)
@ -49,6 +49,9 @@ def main():
parser.add_argument("--layer-skip", action="store_true")
parser.add_argument("--layer-skip-exit-layer", type=int, default=None)
parser.add_argument("--layer-skip-entropy-threshold", type=int, default=0.1)
parser.add_argument("--layer-skip-varentropy-threshold", type=int, default=0.1)
parser.add_argument("--refine-on-stop", action="store_true")
parser.add_argument("--seed", type=int, default=None)
@ -86,6 +89,9 @@ def main():
entropix_sampling=args.entropix_sampling,
layer_skip=args.layer_skip,
layer_skip_exit_layer=args.layer_skip_exit_layer,
layer_skip_entropy_threshold=args.layer_skip_entropy_threshold,
layer_skip_varentropy_threshold=args.layer_skip_varentropy_threshold,
refine_on_stop=args.refine_on_stop,
seed=args.seed,
)

View File

@ -115,7 +115,7 @@ class BaseConfig:
raise Exception(f'Model path does not exist: {model_path}')
# load state dict and copy its stored model config
model_state_dict = [ torch_load( model_path )["config"] | { "path": model_path } ] if model_path and model_path.exists() else []
model_state_dict = [ torch_load( model_path )["config"] | { "path": model_path, "attention": "auto" } ] if model_path and model_path.exists() else []
lora_state_dict = [ torch_load( lora_path )["config"] | { "path": lora_path } ] if lora_path and lora_path.exists() else []
state = { "models": model_state_dict, "loras": lora_state_dict, "trainer": { "load_state_dict": True } }

View File

@ -82,6 +82,12 @@ def main():
parser.add_argument("--dry-allowed-length", type=int, default=2)
parser.add_argument("--entropix-sampling", action="store_true")
parser.add_argument("--layer-skip", action="store_true")
parser.add_argument("--layer-skip-exit-layer", type=int, default=None)
parser.add_argument("--layer-skip-entropy-threshold", type=int, default=0.1)
parser.add_argument("--layer-skip-varentropy-threshold", type=int, default=0.1)
parser.add_argument("--refine-on-stop", action="store_true")
parser.add_argument("--seed", type=int, default=None)

View File

@ -223,6 +223,10 @@ class TTS():
#
layer_skip=False,
layer_skip_exit_layer=-1,
layer_skip_entropy_threshold=-1,
layer_skip_varentropy_threshold=-1,
#
refine_on_stop=False,
#
seed = None,
@ -275,6 +279,7 @@ class TTS():
sampling_entropix=entropix_sampling,
sampling_layer_skip=layer_skip,
sampling_layer_skip_exit_layer=layer_skip_exit_layer,
sampling_refine_on_stop=refine_on_stop,
disable_tqdm=not tqdm,
use_lora=use_lora,
@ -326,6 +331,9 @@ class TTS():
sampling_entropix=entropix_sampling,
sampling_layer_skip=layer_skip,
sampling_layer_skip_exit_layer=layer_skip_exit_layer,
sampling_layer_skip_entropy_threshold=layer_skip_entropy_threshold,
sampling_layer_skip_varentropy_threshold=layer_skip_varentropy_threshold,
sampling_refine_on_stop=refine_on_stop,
disable_tqdm=not tqdm,
use_lora=use_lora,
@ -338,8 +346,10 @@ class TTS():
sampling_min_temperature=min_nar_temp,
sampling_top_p=top_p, sampling_top_k=top_k, sampling_min_p=min_p,
sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay,
#sampling_layer_skip=layer_skip,
#sampling_layer_skip_exit_layer=layer_skip_exit_layer,
sampling_layer_skip=layer_skip,
sampling_layer_skip_exit_layer=layer_skip_exit_layer,
sampling_layer_skip_entropy_threshold=layer_skip_entropy_threshold,
sampling_layer_skip_varentropy_threshold=layer_skip_varentropy_threshold,
disable_tqdm=not tqdm,
use_lora=use_lora,

View File

@ -69,6 +69,10 @@ class AR_NAR(Base):
sampling_layer_skip: bool = False,
sampling_layer_skip_exit_layer: int = -1,
sampling_layer_skip_entropy_threshold: float = -1,
sampling_layer_skip_varentropy_threshold: float = -1,
sampling_refine_on_stop: bool = False,
disable_tqdm=False,
use_lora=None,
@ -208,7 +212,12 @@ class AR_NAR(Base):
sampling_layer_skip_variables = {} if sampling_layer_skip else None
if sampling_layer_skip:
sampling_layer_skip_variables["max_layer"] = sampling_layer_skip_exit_layer if sampling_layer_skip_exit_layer >= 0 else self.n_layers
if sampling_layer_skip_entropy_threshold >= 0:
sampling_layer_skip_variables["entropy_threshold"] = sampling_layer_skip_entropy_threshold
if sampling_layer_skip_varentropy_threshold >= 0:
sampling_layer_skip_variables["varentropy_threshold"] = sampling_layer_skip_varentropy_threshold
if sampling_layer_skip_exit_layer >= 0:
sampling_layer_skip_variables["max_layer"] = sampling_layer_skip_exit_layer
for n in trange( max_levels, desc="NAR", disable=disable_tqdm ):
level = prev_list[0].shape[-1]
@ -292,7 +301,12 @@ class AR_NAR(Base):
sampling_layer_skip_variables = {} if sampling_layer_skip else None
if sampling_layer_skip:
sampling_layer_skip_variables["max_layer"] = sampling_layer_skip_exit_layer if sampling_layer_skip_exit_layer >= 0 else self.n_layers
if sampling_layer_skip_entropy_threshold >= 0:
sampling_layer_skip_variables["entropy_threshold"] = sampling_layer_skip_entropy_threshold
if sampling_layer_skip_varentropy_threshold >= 0:
sampling_layer_skip_variables["varentropy_threshold"] = sampling_layer_skip_varentropy_threshold
if sampling_layer_skip_exit_layer >= 0:
sampling_layer_skip_variables["max_layer"] = sampling_layer_skip_exit_layer
for i, sequence in enumerate( sequence_list ):
# add <bos> to text for STT
@ -377,7 +391,7 @@ class AR_NAR(Base):
if sampled.entropy:
metrics.append( sampled.entropy )
elif sampled.scores:
metrics.append( [ { "p": p[0] } for p in sampled.scores ] )
metrics.append( [ { "p": p[0], "exited_layer": output.exited_layer } for p in sampled.scores ] )
if mirostat is not None:
mirostat = sampled.scores
@ -409,6 +423,8 @@ class AR_NAR(Base):
if stopped.all().item():
break
# to-do for layerskip / speculative sampling: rerun the last sequence again at max depth
if metrics:
from ..plot import plot_sample_metrics
filename = "metrics"
@ -430,6 +446,17 @@ class AR_NAR(Base):
# remove <bos>
sequence_list = [ sequence_list[i][start_slice[i]:] for i, task in enumerate( task_list ) ]
if sampling_refine_on_stop:
# get how much we need to slice from the end
slice_lengths = [ sequence.shape[-1] for sequence in sequence_list ]
# -1 for the stop token
logits = [ logit[-length-1:-1] for logit, length in zip(logits, slice_lengths) ]
# greedy sample from the sequence
refined_list = [ logit.argmax(dim=-1) for logit in logits ]
# to-do: compare scores
# set the "refined" list as the output
sequence_list = refined_list
return sequence_list

View File

@ -39,7 +39,7 @@ from ..emb.qnt import encode_as_embedding
from ..data import get_task_symmap
# these seem more elegant than a dict
Logits = namedtuple('Logits', ['logits', 'state', 'aux_loss', 'attentions', 'hidden_states'])
Logits = namedtuple('Logits', ['logits', 'state', 'aux_loss', 'attentions', 'hidden_states', 'exited_layer'])
Sampled = namedtuple('Sampled', ['out', 'scores', 'entropy'])
LossStats = namedtuple('LossStats', ['loss', 'stats'])
@ -942,7 +942,7 @@ class Base(nn.Module):
# but skip the last state, as it already is normalized
hidden_states = [ x if i == self.n_layers - 1 else self.model.norm(output.hidden_states[i]) for i, state in enumerate( hidden_states ) ]
return Logits(x, state, aux_loss, attentions, hidden_states)
return Logits(x, state, aux_loss, attentions, hidden_states, None)
# takes a bunch of separate lists and parses them into an ordered array of tuples to guide input sequence creation
def inputs(
@ -1444,11 +1444,13 @@ class Base(nn.Module):
):
# return early if it's "good" enough"
# lambda because we need to capture the classifier_quant_levels and mask
exited_layer = self.n_layers
def layer_skip_lambda( layer, logits ):
nonlocal exited_layer
kwargs = {
"logits_entropy": 0.1,
"logits_varentropy": 0.1,
"min_layer": self.n_layers // 4,
"entropy_threshold": 0.05,
"varentropy_threshold": 0.05,
"min_layer": self.n_layers // 2,
"max_layer": self.n_layers,
}
@ -1472,9 +1474,15 @@ class Base(nn.Module):
# calculate metrics
metrics = calculate_entropix_metrics( logits )
# exit early if "good enough""
return metrics["logits_entropy"] < kwargs["logits_entropy"] and metrics["logits_varentropy"] < kwargs["logits_varentropy"]
early = metrics["logits_entropy"] <= kwargs["entropy_threshold"] and metrics["logits_varentropy"] <= kwargs["varentropy_threshold"]
if early:
exited_layer = layer
#print( layer, early, metrics )
return early
x_list = self.inputs_to_embeddings( inputs, quant_levels )
@ -1526,7 +1534,8 @@ class Base(nn.Module):
logits = output.logits
hidden_states = output.hidden_states
# output projection layer with masking
# output projection layer
# the very, very original implementation multiplied by the mask, but the mask only attends to padding, and the padding gets removed anyways
if self.classifier is not None:
logits = self.classifier(logits) # * m
@ -1582,7 +1591,7 @@ class Base(nn.Module):
self.stats = stats
# rewrap, because we're modifying the logits here
return Logits(logits, output.state, output.aux_loss, output.attentions, hidden_states)
return Logits(logits, output.state, output.aux_loss, output.attentions, hidden_states, exited_layer)
def sample(
self,

View File

@ -194,6 +194,9 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
parser.add_argument("--entropix-sampling", action="store_true")
parser.add_argument("--layer-skip", action="store_true")
parser.add_argument("--layer-skip-exit-layer", type=int, default=kwargs["layer-skip-exit-layer"] if cfg.experimental else -1)
parser.add_argument("--layer-skip-entropy-threshold", type=int, default=kwargs["layer-skip-entropy-threshold"] if cfg.experimental else 0.1)
parser.add_argument("--layer-skip-varentropy-threshold", type=int, default=kwargs["layer-skip-varentropy-threshold"] if cfg.experimental else 0.1)
parser.add_argument("--refine-on-stop", action="store_true")
args, unknown = parser.parse_known_args()
tmp = tempfile.NamedTemporaryFile(suffix='.wav')
@ -208,6 +211,9 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
if kwargs.pop("layer-skip", False):
args.layer_skip = True
if kwargs.pop("refine-on-stop", False):
args.refine_on_stop = True
tts = init_tts()
@ -242,8 +248,11 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
dry_base=args.dry_base,
dry_allowed_length=args.dry_allowed_length,
entropix_sampling=args.entropix_sampling,
layer_skip=args.layer_skip,
layer_skip_exit_layer=args.layer_skip_exit_layer,
layer_skip_entropy_threshold=args.layer_skip_entropy_threshold,
layer_skip_varentropy_threshold=args.layer_skip_varentropy_threshold,
refine_on_stop=args.refine_on_stop,
)
wav = wav.squeeze(0).cpu().numpy()
@ -385,6 +394,7 @@ with ui:
layout["inference_tts"]["inputs"]["nar-temp"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (NAR)", info="Modifies the randomness from the samples in the NAR. (0 to greedy sample)")
with gr.Row():
layout["inference_tts"]["inputs"]["layer-skip"] = gr.Checkbox(label="Layer Skip", info="Performs self-speculative early exit 'sampling'")
layout["inference_tts"]["inputs"]["refine-on-stop"] = gr.Checkbox(label="Refine on <stop>", info="Uses the last step's logits for the AR sequence instead.")
layout["inference_tts"]["inputs"]["language"] = gr.Dropdown(choices=get_languages(), label="Language", value="en")
with gr.Tab("Sampler Settings"):
with gr.Row():
@ -393,7 +403,7 @@ with ui:
layout["inference_tts"]["inputs"]["min-p"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.0, step=0.05, label="Min P")
layout["inference_tts"]["inputs"]["beam-width"] = gr.Slider(value=0, minimum=0, maximum=32, step=1, label="Beam Width", info="Number of branches to search through for beam search sampling.")
with gr.Row():
layout["inference_tts"]["inputs"]["repetition-penalty"] = gr.Slider(value=1.125, minimum=-2.0, maximum=2.0, step=0.05, label="Repetition Penalty", info="Incurs a penalty to tokens based on how often they appear in a sequence.")
layout["inference_tts"]["inputs"]["repetition-penalty"] = gr.Slider(value=1.5, minimum=-2.0, maximum=2.0, step=0.05, label="Repetition Penalty", info="Incurs a penalty to tokens based on how often they appear in a sequence.")
layout["inference_tts"]["inputs"]["repetition-penalty-decay"] = gr.Slider(value=0.0, minimum=-2.0, maximum=2.0, step=0.05, label="Repetition Penalty Length Decay", info="Modifies the reptition penalty based on how far back in time the token appeared in the sequence.")
layout["inference_tts"]["inputs"]["length-penalty"] = gr.Slider(value=0.0, minimum=-2.0, maximum=2.0, step=0.05, label="Length Penalty", info="(AR only) Modifies the probability of a stop token based on the current length of the sequence.")
with gr.Row():
@ -415,6 +425,8 @@ with ui:
layout["inference_tts"]["inputs"]["entropix-sampling"] = gr.Checkbox(label="Entropix Sampling", info="Dynamically samples based on entropy/varentropy values from the logits / attention scores.")
with gr.Row():
layout["inference_tts"]["inputs"]["layer-skip-exit-layer"] = gr.Slider(value=11, minimum=0, maximum=11, step=1, label="Layer Skip Exit Layer", info="Maximum model layer to exit early from.")
layout["inference_tts"]["inputs"]["layer-skip-entropy-threshold"] = gr.Slider(value=0.1, minimum=0, maximum=1.0, step=0.01, label="Layer Skip Entropy Threshold", info="Entropy threshold for early-exit")
layout["inference_tts"]["inputs"]["layer-skip-varentropy-threshold"] = gr.Slider(value=0.1, minimum=0, maximum=1.0, step=0.01, label="Layer Skip Varentropy Threshold", info="Varentropy threshold for early-exit")
layout["inference_tts"]["buttons"]["inference"].click(