more adjustments (adjustments of early-exit entropy/varentropy thresholds, default rep pen being 1.5, experimental refine-on-stop, etc.)
parent aee08b7307
commit d229725c76
@@ -240,8 +240,13 @@ And some experimental sampling flags you can use too (your mileage will ***definitely*** vary):
 * `--dry-multiplier`: (AR only) performs DRY sampling, the scalar factor.
 * `--dry-base`: (AR only) for DRY sampling, the base of the exponent factor.
 * `--dry-allowed-length`: (AR only) for DRY sampling, the window to perform DRY sampling within.
-* `--layer-skip` (AR only) enables early-exit layer skipping if the model is confident enough (for compatible models)
-* `--layer-skip-exit-layer`: (AR only) maximum layer to use (for compatible models)
+* `--layer-skip`: enables early-exit layer skipping if the model is confident enough (for compatible models)
+* `--layer-skip-exit-layer`: maximum layer to use
+* `--layer-skip-entropy-threshold`: the maximum entropy (confidence) the logits can have for the model to exit early
+* `--layer-skip-varentropy-threshold`: the maximum varentropy (confidence spread) the logits can have for the model to exit early
+* `--refine-on-stop`: (AR only) uses the last step's logits for the entire final output sequence, rather than the step-by-step iterative sequence.
+  + This needs experimenting with to see if there's any downside.
+  + to-do: compare the probability scores with the original output sequence, and pick the best one.
 
 ### Speech-to-Text
 
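For context on what those two thresholds are compared against: the early-exit check looks at the entropy and varentropy of the intermediate logits. A minimal sketch of that math, assuming the standard definitions (the repo's own `calculate_entropix_metrics` may differ in log base or normalization):

```python
import torch

def logits_confidence(logits: torch.Tensor) -> tuple[float, float]:
	# entropy: how peaked the next-token distribution is (low = confident)
	log_probs = torch.log_softmax(logits, dim=-1)
	probs = log_probs.exp()
	entropy = -(probs * log_probs).sum(dim=-1)
	# varentropy: spread of per-token surprisal around that entropy (low = uniformly confident)
	varentropy = (probs * (log_probs + entropy.unsqueeze(-1)) ** 2).sum(dim=-1)
	return entropy.mean().item(), varentropy.mean().item()

# a layer is eligible for early exit only when both metrics sit at or below their thresholds:
# entropy, varentropy = logits_confidence(logits)
# exit_early = entropy <= 0.1 and varentropy <= 0.1
```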
@@ -33,7 +33,7 @@ def main():
 	parser.add_argument("--top-p", type=float, default=1.0)
 	parser.add_argument("--top-k", type=int, default=0)
 	parser.add_argument("--min-p", type=float, default=0.0)
-	parser.add_argument("--repetition-penalty", type=float, default=1.125)
+	parser.add_argument("--repetition-penalty", type=float, default=1.5)
 	parser.add_argument("--repetition-penalty-decay", type=float, default=0.0)
 	parser.add_argument("--length-penalty", type=float, default=0.0)
 	parser.add_argument("--beam-width", type=int, default=0)
@@ -49,6 +49,9 @@ def main():
 
 	parser.add_argument("--layer-skip", action="store_true")
 	parser.add_argument("--layer-skip-exit-layer", type=int, default=None)
+	parser.add_argument("--layer-skip-entropy-threshold", type=float, default=0.1)
+	parser.add_argument("--layer-skip-varentropy-threshold", type=float, default=0.1)
+	parser.add_argument("--refine-on-stop", action="store_true")
 
 	parser.add_argument("--seed", type=int, default=None)
 
@@ -86,6 +89,9 @@ def main():
 		entropix_sampling=args.entropix_sampling,
 		layer_skip=args.layer_skip,
 		layer_skip_exit_layer=args.layer_skip_exit_layer,
+		layer_skip_entropy_threshold=args.layer_skip_entropy_threshold,
+		layer_skip_varentropy_threshold=args.layer_skip_varentropy_threshold,
+		refine_on_stop=args.refine_on_stop,
 		seed=args.seed,
 	)
 
@@ -115,7 +115,7 @@ class BaseConfig:
 			raise Exception(f'Model path does not exist: {model_path}')
 
 		# load state dict and copy its stored model config
-		model_state_dict = [ torch_load( model_path )["config"] | { "path": model_path } ] if model_path and model_path.exists() else []
+		model_state_dict = [ torch_load( model_path )["config"] | { "path": model_path, "attention": "auto" } ] if model_path and model_path.exists() else []
 		lora_state_dict = [ torch_load( lora_path )["config"] | { "path": lora_path } ] if lora_path and lora_path.exists() else []
 
 		state = { "models": model_state_dict, "loras": lora_state_dict, "trainer": { "load_state_dict": True } }
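The substantive change here is the extra `"attention": "auto"` key: with Python's dict union, right-hand keys win, so whatever attention backend the checkpoint's stored config recorded gets overridden to `auto` at load time. A tiny illustration with made-up stored values:

```python
stored_config = { "attention": "flash_attention_2", "dim": 1024 }   # hypothetical checkpoint config
merged = stored_config | { "path": "./model.sft", "attention": "auto" }
assert merged == { "attention": "auto", "dim": 1024, "path": "./model.sft" }
```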
@@ -82,6 +82,12 @@ def main():
 	parser.add_argument("--dry-allowed-length", type=int, default=2)
 
 	parser.add_argument("--entropix-sampling", action="store_true")
 
+	parser.add_argument("--layer-skip", action="store_true")
+	parser.add_argument("--layer-skip-exit-layer", type=int, default=None)
+	parser.add_argument("--layer-skip-entropy-threshold", type=float, default=0.1)
+	parser.add_argument("--layer-skip-varentropy-threshold", type=float, default=0.1)
+	parser.add_argument("--refine-on-stop", action="store_true")
+
 	parser.add_argument("--seed", type=int, default=None)
 
@@ -223,6 +223,10 @@ class TTS():
 		#
 		layer_skip=False,
 		layer_skip_exit_layer=-1,
+		layer_skip_entropy_threshold=-1,
+		layer_skip_varentropy_threshold=-1,
+		#
+		refine_on_stop=False,
 		#
 		seed = None,
 
@@ -275,6 +279,7 @@ class TTS():
 				sampling_entropix=entropix_sampling,
 				sampling_layer_skip=layer_skip,
 				sampling_layer_skip_exit_layer=layer_skip_exit_layer,
+				sampling_refine_on_stop=refine_on_stop,
 
 				disable_tqdm=not tqdm,
 				use_lora=use_lora,
@@ -326,6 +331,9 @@ class TTS():
 				sampling_entropix=entropix_sampling,
 				sampling_layer_skip=layer_skip,
 				sampling_layer_skip_exit_layer=layer_skip_exit_layer,
+				sampling_layer_skip_entropy_threshold=layer_skip_entropy_threshold,
+				sampling_layer_skip_varentropy_threshold=layer_skip_varentropy_threshold,
+				sampling_refine_on_stop=refine_on_stop,
 
 				disable_tqdm=not tqdm,
 				use_lora=use_lora,
@@ -338,8 +346,10 @@ class TTS():
 				sampling_min_temperature=min_nar_temp,
 				sampling_top_p=top_p, sampling_top_k=top_k, sampling_min_p=min_p,
 				sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay,
-				#sampling_layer_skip=layer_skip,
-				#sampling_layer_skip_exit_layer=layer_skip_exit_layer,
+				sampling_layer_skip=layer_skip,
+				sampling_layer_skip_exit_layer=layer_skip_exit_layer,
+				sampling_layer_skip_entropy_threshold=layer_skip_entropy_threshold,
+				sampling_layer_skip_varentropy_threshold=layer_skip_varentropy_threshold,
 
 				disable_tqdm=not tqdm,
 				use_lora=use_lora,
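Plumbed end to end, the new controls are plain keyword arguments on the high-level `TTS` wrapper. A hypothetical call might look like the following; only the keyword names shown in the signature above are taken from the source, while the import path, constructor arguments, the remaining inputs, and the return values are assumed:

```python
from vall_e.inference import TTS  # import path assumed

tts = TTS()  # model/config selection omitted

wav, sr = tts.inference(
	# ...text, reference audio, language, and the usual sampler settings go here...
	# early-exit layer skipping, gated on logit entropy/varentropy
	layer_skip=True,
	layer_skip_exit_layer=-1,               # -1: allow exiting at any layer up to the last
	layer_skip_entropy_threshold=0.1,
	layer_skip_varentropy_threshold=0.1,
	# re-decode the entire output from the final step's logits
	refine_on_stop=True,
)
```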
@@ -69,6 +69,10 @@ class AR_NAR(Base):
 
 		sampling_layer_skip: bool = False,
 		sampling_layer_skip_exit_layer: int = -1,
+		sampling_layer_skip_entropy_threshold: float = -1,
+		sampling_layer_skip_varentropy_threshold: float = -1,
 
+		sampling_refine_on_stop: bool = False,
+
 		disable_tqdm=False,
 		use_lora=None,
 
@@ -208,7 +212,12 @@ class AR_NAR(Base):
 			sampling_layer_skip_variables = {} if sampling_layer_skip else None
 
 			if sampling_layer_skip:
-				sampling_layer_skip_variables["max_layer"] = sampling_layer_skip_exit_layer if sampling_layer_skip_exit_layer >= 0 else self.n_layers
+				if sampling_layer_skip_entropy_threshold >= 0:
+					sampling_layer_skip_variables["entropy_threshold"] = sampling_layer_skip_entropy_threshold
+				if sampling_layer_skip_varentropy_threshold >= 0:
+					sampling_layer_skip_variables["varentropy_threshold"] = sampling_layer_skip_varentropy_threshold
+				if sampling_layer_skip_exit_layer >= 0:
+					sampling_layer_skip_variables["max_layer"] = sampling_layer_skip_exit_layer
 
 			for n in trange( max_levels, desc="NAR", disable=disable_tqdm ):
 				level = prev_list[0].shape[-1]
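Traced through that branch, example flag values end up as a small dict that is handed to the forward pass to configure the early-exit hook (values here are illustrative):

```python
# e.g. --layer-skip --layer-skip-exit-layer 6 \
#      --layer-skip-entropy-threshold 0.1 --layer-skip-varentropy-threshold 0.1
sampling_layer_skip_variables = {
	"entropy_threshold": 0.1,
	"varentropy_threshold": 0.1,
	"max_layer": 6,   # omitted when the flag stays at -1, letting the model-side default apply
}
```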
@@ -292,7 +301,12 @@ class AR_NAR(Base):
 		sampling_layer_skip_variables = {} if sampling_layer_skip else None
 
 		if sampling_layer_skip:
-			sampling_layer_skip_variables["max_layer"] = sampling_layer_skip_exit_layer if sampling_layer_skip_exit_layer >= 0 else self.n_layers
+			if sampling_layer_skip_entropy_threshold >= 0:
+				sampling_layer_skip_variables["entropy_threshold"] = sampling_layer_skip_entropy_threshold
+			if sampling_layer_skip_varentropy_threshold >= 0:
+				sampling_layer_skip_variables["varentropy_threshold"] = sampling_layer_skip_varentropy_threshold
+			if sampling_layer_skip_exit_layer >= 0:
+				sampling_layer_skip_variables["max_layer"] = sampling_layer_skip_exit_layer
 
 		for i, sequence in enumerate( sequence_list ):
 			# add <bos> to text for STT
@@ -377,7 +391,7 @@ class AR_NAR(Base):
 			if sampled.entropy:
 				metrics.append( sampled.entropy )
 			elif sampled.scores:
-				metrics.append( [ { "p": p[0] } for p in sampled.scores ] )
+				metrics.append( [ { "p": p[0], "exited_layer": output.exited_layer } for p in sampled.scores ] )
 
 			if mirostat is not None:
 				mirostat = sampled.scores
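With that change, each AR step's metrics entry also records which layer the forward pass exited at, so one entry looks roughly like this (illustrative values, one dict per sequence in the batch):

```python
step_metrics = [
	{ "p": 0.97, "exited_layer": 9 },    # sequence 0: sampled-token probability and exit layer
	{ "p": 0.88, "exited_layer": 12 },   # sequence 1
]
```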
@@ -409,6 +423,8 @@ class AR_NAR(Base):
 			if stopped.all().item():
 				break
 
+		# to-do for layerskip / speculative sampling: rerun the last sequence again at max depth
+
 		if metrics:
 			from ..plot import plot_sample_metrics
 			filename = "metrics"
@@ -430,6 +446,17 @@ class AR_NAR(Base):
 		# remove <bos>
 		sequence_list = [ sequence_list[i][start_slice[i]:] for i, task in enumerate( task_list ) ]
 
+		if sampling_refine_on_stop:
+			# get how much we need to slice from the end
+			slice_lengths = [ sequence.shape[-1] for sequence in sequence_list ]
+			# -1 for the stop token
+			logits = [ logit[-length-1:-1] for logit, length in zip(logits, slice_lengths) ]
+			# greedy sample from the sequence
+			refined_list = [ logit.argmax(dim=-1) for logit in logits ]
+			# to-do: compare scores
+			# set the "refined" list as the output
+			sequence_list = refined_list
+
 		return sequence_list
 
 
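The `# to-do: compare scores` above could be done by scoring each candidate sequence as a sum of per-token log-probabilities and keeping the higher-scoring one. A self-contained sketch of that idea, not what this commit implements (in practice the original sequence would have to be scored against the logits that actually produced it):

```python
import torch

def sequence_log_prob(logits: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor:
	"""Sum of per-token log-probabilities of `tokens` ([T]) under `logits` ([T, V])."""
	return logits.log_softmax(dim=-1).gather(-1, tokens.unsqueeze(-1)).sum()

# e.g. keep the refined output only if it scores at least as well as the original:
# keep = refined if sequence_log_prob(final_step_logits, refined) >= sequence_log_prob(stepwise_logits, original) else original
```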
@@ -39,7 +39,7 @@ from ..emb.qnt import encode_as_embedding
 from ..data import get_task_symmap
 
 # these seem more elegant than a dict
-Logits = namedtuple('Logits', ['logits', 'state', 'aux_loss', 'attentions', 'hidden_states'])
+Logits = namedtuple('Logits', ['logits', 'state', 'aux_loss', 'attentions', 'hidden_states', 'exited_layer'])
 Sampled = namedtuple('Sampled', ['out', 'scores', 'entropy'])
 LossStats = namedtuple('LossStats', ['loss', 'stats'])
 
@@ -942,7 +942,7 @@ class Base(nn.Module):
 			# but skip the last state, as it already is normalized
 			hidden_states = [ x if i == self.n_layers - 1 else self.model.norm(output.hidden_states[i]) for i, state in enumerate( hidden_states ) ]
 
-		return Logits(x, state, aux_loss, attentions, hidden_states)
+		return Logits(x, state, aux_loss, attentions, hidden_states, None)
 
 	# takes a bunch of separate lists and parses them into an ordered array of tuples to guide input sequence creation
 	def inputs(
@@ -1444,11 +1444,13 @@ class Base(nn.Module):
 	):
 		# return early if it's "good enough"
 		# lambda because we need to capture the classifier_quant_levels and mask
+		exited_layer = self.n_layers
 		def layer_skip_lambda( layer, logits ):
+			nonlocal exited_layer
 			kwargs = {
-				"logits_entropy": 0.1,
-				"logits_varentropy": 0.1,
-				"min_layer": self.n_layers // 4,
+				"entropy_threshold": 0.05,
+				"varentropy_threshold": 0.05,
+				"min_layer": self.n_layers // 2,
 				"max_layer": self.n_layers,
 			}
 
@@ -1472,9 +1474,15 @@ class Base(nn.Module):
 
 			# calculate metrics
 			metrics = calculate_entropix_metrics( logits )
 
 			# exit early if "good enough"
-			return metrics["logits_entropy"] < kwargs["logits_entropy"] and metrics["logits_varentropy"] < kwargs["logits_varentropy"]
+			early = metrics["logits_entropy"] <= kwargs["entropy_threshold"] and metrics["logits_varentropy"] <= kwargs["varentropy_threshold"]
+
+			if early:
+				exited_layer = layer
+
+			#print( layer, early, metrics )
+			return early
 
 		x_list = self.inputs_to_embeddings( inputs, quant_levels )
 
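For a picture of where a hook shaped like `layer_skip_lambda` sits, here is a sketch of a decoder loop consuming it; the actual wiring into the underlying model's forward pass isn't shown in this diff, so the function below and its arguments are purely illustrative:

```python
from typing import Callable, Optional
import torch
import torch.nn as nn

def forward_with_early_exit(
	layers: nn.ModuleList,
	norm: nn.Module,
	classifier: nn.Module,
	x: torch.Tensor,
	early_exit_fn: Optional[Callable[[int, torch.Tensor], bool]] = None,
) -> torch.Tensor:
	for i, layer in enumerate(layers):
		x = layer(x)
		if early_exit_fn is not None:
			# project the intermediate hidden state the same way the final head would
			logits = classifier(norm(x))
			# the hook decides "confident enough" and records the exit layer via its closure
			if early_exit_fn(i, logits):
				break
	return classifier(norm(x))
```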
@@ -1526,7 +1534,8 @@ class Base(nn.Module):
 		logits = output.logits
 		hidden_states = output.hidden_states
 
-		# output projection layer with masking
+		# output projection layer
+		# the very, very original implementation multiplied by the mask, but the mask only attends to padding, and the padding gets removed anyways
 		if self.classifier is not None:
 			logits = self.classifier(logits) # * m
 
@@ -1582,7 +1591,7 @@ class Base(nn.Module):
 		self.stats = stats
 
 		# rewrap, because we're modifying the logits here
-		return Logits(logits, output.state, output.aux_loss, output.attentions, hidden_states)
+		return Logits(logits, output.state, output.aux_loss, output.attentions, hidden_states, exited_layer)
 
 	def sample(
 		self,
@@ -194,6 +194,9 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
 	parser.add_argument("--entropix-sampling", action="store_true")
 	parser.add_argument("--layer-skip", action="store_true")
 	parser.add_argument("--layer-skip-exit-layer", type=int, default=kwargs["layer-skip-exit-layer"] if cfg.experimental else -1)
+	parser.add_argument("--layer-skip-entropy-threshold", type=float, default=kwargs["layer-skip-entropy-threshold"] if cfg.experimental else 0.1)
+	parser.add_argument("--layer-skip-varentropy-threshold", type=float, default=kwargs["layer-skip-varentropy-threshold"] if cfg.experimental else 0.1)
+	parser.add_argument("--refine-on-stop", action="store_true")
 	args, unknown = parser.parse_known_args()
 
 	tmp = tempfile.NamedTemporaryFile(suffix='.wav')
@@ -208,6 +211,9 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
 
 	if kwargs.pop("layer-skip", False):
 		args.layer_skip = True
 
+	if kwargs.pop("refine-on-stop", False):
+		args.refine_on_stop = True
+
 	tts = init_tts()
 
@@ -242,8 +248,11 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
 			dry_base=args.dry_base,
 			dry_allowed_length=args.dry_allowed_length,
 			entropix_sampling=args.entropix_sampling,
 
 			layer_skip=args.layer_skip,
 			layer_skip_exit_layer=args.layer_skip_exit_layer,
+			layer_skip_entropy_threshold=args.layer_skip_entropy_threshold,
+			layer_skip_varentropy_threshold=args.layer_skip_varentropy_threshold,
+			refine_on_stop=args.refine_on_stop,
 		)
 
 		wav = wav.squeeze(0).cpu().numpy()
@@ -385,6 +394,7 @@ with ui:
 				layout["inference_tts"]["inputs"]["nar-temp"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (NAR)", info="Modifies the randomness from the samples in the NAR. (0 to greedy sample)")
 			with gr.Row():
 				layout["inference_tts"]["inputs"]["layer-skip"] = gr.Checkbox(label="Layer Skip", info="Performs self-speculative early exit 'sampling'")
+				layout["inference_tts"]["inputs"]["refine-on-stop"] = gr.Checkbox(label="Refine on <stop>", info="Uses the last step's logits for the AR sequence instead.")
 				layout["inference_tts"]["inputs"]["language"] = gr.Dropdown(choices=get_languages(), label="Language", value="en")
 		with gr.Tab("Sampler Settings"):
 			with gr.Row():
@@ -393,7 +403,7 @@ with ui:
 				layout["inference_tts"]["inputs"]["min-p"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.0, step=0.05, label="Min P")
 				layout["inference_tts"]["inputs"]["beam-width"] = gr.Slider(value=0, minimum=0, maximum=32, step=1, label="Beam Width", info="Number of branches to search through for beam search sampling.")
 			with gr.Row():
-				layout["inference_tts"]["inputs"]["repetition-penalty"] = gr.Slider(value=1.125, minimum=-2.0, maximum=2.0, step=0.05, label="Repetition Penalty", info="Incurs a penalty to tokens based on how often they appear in a sequence.")
+				layout["inference_tts"]["inputs"]["repetition-penalty"] = gr.Slider(value=1.5, minimum=-2.0, maximum=2.0, step=0.05, label="Repetition Penalty", info="Incurs a penalty to tokens based on how often they appear in a sequence.")
 				layout["inference_tts"]["inputs"]["repetition-penalty-decay"] = gr.Slider(value=0.0, minimum=-2.0, maximum=2.0, step=0.05, label="Repetition Penalty Length Decay", info="Modifies the repetition penalty based on how far back in time the token appeared in the sequence.")
 				layout["inference_tts"]["inputs"]["length-penalty"] = gr.Slider(value=0.0, minimum=-2.0, maximum=2.0, step=0.05, label="Length Penalty", info="(AR only) Modifies the probability of a stop token based on the current length of the sequence.")
 			with gr.Row():
@@ -415,6 +425,8 @@ with ui:
 				layout["inference_tts"]["inputs"]["entropix-sampling"] = gr.Checkbox(label="Entropix Sampling", info="Dynamically samples based on entropy/varentropy values from the logits / attention scores.")
 			with gr.Row():
 				layout["inference_tts"]["inputs"]["layer-skip-exit-layer"] = gr.Slider(value=11, minimum=0, maximum=11, step=1, label="Layer Skip Exit Layer", info="Maximum model layer to exit early from.")
+				layout["inference_tts"]["inputs"]["layer-skip-entropy-threshold"] = gr.Slider(value=0.1, minimum=0, maximum=1.0, step=0.01, label="Layer Skip Entropy Threshold", info="Entropy threshold for early-exit")
+				layout["inference_tts"]["inputs"]["layer-skip-varentropy-threshold"] = gr.Slider(value=0.1, minimum=0, maximum=1.0, step=0.01, label="Layer Skip Varentropy Threshold", info="Varentropy threshold for early-exit")
 
 
 	layout["inference_tts"]["buttons"]["inference"].click(