very, very naive layerskip speculative sampling (it just checks if the current layer's state is good enough)

mrq 2024-11-02 11:49:05 -05:00
parent 62fe5b0943
commit ded746e157
6 changed files with 110 additions and 43 deletions
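For orientation before the per-file diffs: a minimal, hedged sketch of what "checks if the current layer's state is good enough" means here. Names and signatures are illustrative rather than the repo's exact API, and only the entropy half of the check is shown; the actual lambda (in the Base class below) thresholds both entropy and varentropy.

import torch
import torch.nn.functional as F

def forward_with_early_exit(layers, norm, classifier, hidden,
                            min_layer: int, max_layer: int,
                            entropy_threshold: float = 0.1) -> torch.Tensor:
    # run the decoder stack, exiting once an intermediate layer's state,
    # normalized and projected through the output head, is confident enough
    for l, layer in enumerate(layers):
        hidden = layer(hidden)
        if l < min_layer:    # don't bother on early layers
            continue
        if l > max_layer:    # force an exit past the configured ceiling
            break
        logits = classifier(norm(hidden))  # hidden states aren't normalized
        log_probs = F.log_softmax(logits[-1, :], dim=-1)
        entropy = -(log_probs.exp() * log_probs).sum()
        if entropy < entropy_threshold:    # "good enough": exit early
            break
    return classifier(norm(hidden))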

View File

@@ -14,10 +14,10 @@ from torch.nn.utils.rnn import pad_sequence
import random
import math
import time
from einops import rearrange
from torch import Tensor
from tqdm import trange
from time import perf_counter
import logging
@@ -66,6 +66,7 @@ class AR_NAR(Base):
sampling_dry_base=1.75,
sampling_dry_allowed_length=2,
sampling_entropix=False,
sampling_layer_skip: bool = False,
sampling_layer_skip_exit_layer: int = -1,
@@ -281,6 +282,11 @@ class AR_NAR(Base):
original_sampling_repetition_penalty_decay = sampling_repetition_penalty_decay
"""
sampling_layer_skip_variables = {} if sampling_layer_skip else None
if sampling_layer_skip:
sampling_layer_skip_variables["max_layer"] = sampling_layer_skip_exit_layer if sampling_layer_skip_exit_layer >= 0 else self.n_layers
for i, sequence in enumerate( sequence_list ):
# add <bos> to text for STT
if task_list[i] in text_task:
@@ -329,7 +335,7 @@
inputs=inputs,
state=state,
layer_skip_exit_layer=sampling_layer_skip_exit_layer,
layer_skip_variables=sampling_layer_skip_variables,
output_attentions=sampling_entropix,
)
@@ -360,15 +366,11 @@
r = sampled[0]
if sampled.entropy:
metrics.append( sampled.entropy )
"""
elif sampled.confidence:
metrics.append( sampled.confidence )
"""
elif False:
p = [ { "p": torch.nn.functional.softmax(logit[-1, :].cpu(), dim=0)[token.item()].item() } for logit, token in zip(logits, r) ]
metrics.append( p )
if cfg.experimental:
if sampled.entropy:
metrics.append( sampled.entropy )
elif sampled.scores:
metrics.append( [ { "p": p[0] } for p in sampled.scores ] )
if mirostat is not None:
mirostat = sampled.scores
@@ -402,7 +404,13 @@
if metrics:
from ..plot import plot_sample_metrics
plot_sample_metrics( metrics )
filename = "metrics"
if sampling_entropix:
filename += f'[entropix]'
if sampling_layer_skip_exit_layer >= 0:
filename += f'[{sampling_layer_skip_exit_layer+1}]'
plot_sample_metrics( metrics, filename=f'{filename}.png' )
# pick the best scoring candidate
# desu this is always going to be candidate 0

View File

@@ -358,7 +358,8 @@ class LlamaModel_Adapted(LlamaModel):
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
exit_layer: Optional[int] = -1,
layer_skip_lambda = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@@ -451,7 +452,9 @@ class LlamaModel_Adapted(LlamaModel):
if output_attentions:
all_self_attns += (layer_outputs[1],)
if 0 <= exit_layer and exit_layer <= l:
# check if we should early-exit
if layer_skip_lambda and layer_skip_lambda( l, hidden_states ):
#_logger.info(f"Early exit at layer: {l}")
break
hidden_states = self.norm(hidden_states)
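The caller-supplied predicate subsumes the old fixed-index behaviour; as a hedged sketch, the removed exit_layer check is just one particular lambda (the exit_layer value here is hypothetical):

# equivalent of the removed `0 <= exit_layer and exit_layer <= l` check,
# expressed as a predicate; the hidden state is ignored in this degenerate case
exit_layer = 6  # hypothetical
layer_skip_lambda = lambda l, hidden_states: 0 <= exit_layer <= l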

View File

@@ -38,8 +38,9 @@ from ..emb.qnt import encode_as_embedding
# yuck, kind of needed
from ..data import get_task_symmap
# these seem more elegant than a dict
Logits = namedtuple('Logits', ['logits', 'state', 'aux_loss', 'attentions', 'hidden_states'])
Sampled = namedtuple('Sampled', ['out', 'scores', 'entropy']) # these seem more elegant than a dict
Sampled = namedtuple('Sampled', ['out', 'scores', 'entropy'])
LossStats = namedtuple('LossStats', ['loss', 'stats'])
"""
@@ -476,6 +477,7 @@ class Base(nn.Module):
self.unified_position_ids = unified_position_ids
self.interleave = interleave
self.layerskip = layerskip
self.special_tasks = [ "len", "stt" ]
self.text_emb = Embedding(n_text_tokens, d_model)
self.langs_emb = None
@@ -827,7 +829,7 @@ class Base(nn.Module):
state = None,
layer_skip_exit_layer = -1,
layer_skip_lambda = None,
output_attentions = False,
output_hidden_states = False,
@@ -846,7 +848,7 @@ class Base(nn.Module):
inputs_embeds=x,
past_key_values=state,
position_ids=position_ids,
use_cache=not self.training,
use_cache=False, # not self.training,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=True,
@@ -855,8 +857,8 @@ class Base(nn.Module):
if self.n_experts > 1 and self.training:
kwargs["output_router_logits"] = True
if self.layerskip and 0 <= layer_skip_exit_layer and layer_skip_exit_layer < self.n_layers:
kwargs["exit_layer"] = layer_skip_exit_layer
if self.layerskip and layer_skip_lambda is not None:
kwargs["layer_skip_lambda"] = layer_skip_lambda
output = self.model(**kwargs)
x = output["last_hidden_state"]
@@ -938,14 +940,6 @@ class Base(nn.Module):
# but skip the last state, as it already is normalized
hidden_states = [ x if i == self.n_layers - 1 else self.model.norm(output.hidden_states[i]) for i, state in enumerate( hidden_states ) ]
# output projection layer with masking
if self.classifier is not None:
x = self.classifier(x) * mask
if output.hidden_states:
for i, state in enumerate( hidden_states ):
hidden_states[i] = self.classifier(hidden_states[i]) * m
return Logits(x, state, aux_loss, attentions, hidden_states)
# takes a bunch of separate lists and parses them into an ordered array of tuples to guide input sequence creation
@@ -965,8 +959,6 @@ class Base(nn.Module):
device = text_list[0].device
batch_size = len(text_list)
special_tasks = ["stt", "len"]
inputs = [ [] for _ in range(batch_size) ]
for i in range(batch_size):
quant_level = quant_levels[i] if quant_levels is not None else 0
@@ -981,7 +973,7 @@ class Base(nn.Module):
# Base-line TTS task
# Sequence: <text><sep><rvq lvl><sep><prom><sep><resp>
# prom /may/ include <task> tokens inside to help guide things, per SpeechX
if f'<{task_type}>' in get_task_symmap() and task_type not in special_tasks:
if f'<{task_type}>' in get_task_symmap() and task_type not in self.special_tasks:
# insert the text prompt
if text_list is not None and text_list[i] is not None:
inputs[i].append( ( "text", text_list[i] ) )
@@ -1259,9 +1251,8 @@ class Base(nn.Module):
stats = dict(acc = dict())
device = logits[0].device
special_tasks = [ "len", "stt" ]
summed_embeddings_task = [ "stt" ]
classifier_quant_levels = quant_levels if self.classifier is not None else [ -1 if inputs[i][0][-1] in special_tasks else l for i, l in enumerate( quant_levels ) ]
classifier_quant_levels = quant_levels if self.classifier is not None else [ -1 if inputs[i][0][-1] in self.special_tasks else l for i, l in enumerate( quant_levels ) ]
# handles tasks where the prompt has task tokens injected in the middle
def prompt_input_to_token( input, quant_level ):
@@ -1443,11 +1434,46 @@ class Base(nn.Module):
quant_levels: int | list[int] | Tensor | None = None,
state: dict | list | None = None,
layer_skip_exit_layer: int = -1,
layer_skip_variables: dict | None = None,
output_attentions: bool = False,
output_hidden_states: bool = False,
):
# return early if it's "good enough"
# defined as a closure so it can capture classifier_quant_levels and the mask
def layer_skip_lambda( layer, logits ):
kwargs = {
"logits_entropy": 0.1,
"logits_varentropy": 0.1,
"min_layer": self.n_layers // 2,
"max_layer": self.n_layers,
}
kwargs.update( layer_skip_variables )
# don't bother on early layers
if layer < kwargs["min_layer"]:
return False
# bail if we want to force early layers
if kwargs["max_layer"] < layer:
return True
# hidden states aren't normalized
x = self.model.norm( logits )
# output projection layer with masking
if self.classifier is not None:
x = self.classifier(x) * m
elif self.classifiers is not None:
x = self.classifiers(x, levels = classifier_quant_levels) * m
# calculate metrics on the projected logits
metrics = calculate_entropix_metrics( x )
# exit early if "good enough"
return metrics["logits_entropy"] < kwargs["logits_entropy"] and metrics["logits_varentropy"] < kwargs["logits_varentropy"]
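calculate_entropix_metrics itself isn't part of this diff; a hedged sketch of the two fields the lambda reads, assuming the usual entropy/varentropy definitions over the final position's softmax:

import torch
import torch.nn.functional as F

def calculate_entropix_metrics(logits: torch.Tensor) -> dict:
    # logits: (seq_len, vocab); metrics over the final position
    log_probs = F.log_softmax(logits[-1, :], dim=-1)
    probs = log_probs.exp()
    entropy = -(probs * log_probs).sum()                     # H = -sum p*log p
    varentropy = (probs * (log_probs + entropy) ** 2).sum()  # Var[-log p]
    return {
        "logits_entropy": entropy.item(),
        "logits_varentropy": varentropy.item(),
    }

With that shape, the caller-built layer_skip_variables dict only needs the keys it wants to override, e.g. { "max_layer": 9, "logits_entropy": 0.05 } (values hypothetical).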
x_list = self.inputs_to_embeddings( inputs, quant_levels )
x, m = list_to_tensor(x_list)
@@ -1459,7 +1485,8 @@ class Base(nn.Module):
if quant_levels is None:
quant_levels = [ 0 for _ in range(batch_size) ]
if self.layerskip:
# we only need hidden states if we're training with layerskip
if self.layerskip and training:
output_hidden_states = True
# pad our input and mask, but retain the original length by doing it after
@@ -1478,6 +1505,8 @@ class Base(nn.Module):
# needs to be done here as we still have our raw inputs
position_ids = self.inputs_to_position_ids( inputs, mask=m.squeeze(-1).int() ) if not self.unified_position_ids else None
classifier_quant_levels = [ -1 if inputs[i][0][-1] in self.special_tasks else l for i, l in enumerate( quant_levels ) ]
output = self._forward(
inputs=x,
@@ -1486,17 +1515,22 @@ class Base(nn.Module):
position_ids=position_ids,
output_attentions = output_attentions,
output_hidden_states = output_hidden_states,
layer_skip_exit_layer = layer_skip_exit_layer,
layer_skip_lambda = layer_skip_lambda if self.layerskip and layer_skip_variables else None,
)
logits = output.logits
hidden_states = output.hidden_states
# output projection layer with masking
if self.classifier is not None:
logits = self.classifier(logits) * m
if output.hidden_states:
for i, state in enumerate( hidden_states ):
hidden_states[i] = self.classifier(hidden_states[i]) * m
# to-do: piece-wise classification, now that there's a head for text
# although again, one single monolithic head would be preferable instead......
if self.classifiers is not None:
special_tasks = [ "len", "stt" ]
classifier_quant_levels = [ -1 if inputs[i][0][-1] in special_tasks else l for i, l in enumerate( quant_levels ) ]
logits = self.classifiers(logits, levels = classifier_quant_levels) * m
if hidden_states is not None:
@@ -1508,7 +1542,6 @@ class Base(nn.Module):
if hidden_states is not None:
for i, state in enumerate( hidden_states ):
# remove padding
hidden_states[i] = [ hi[:li] for hi, li in zip(hidden_states[i], map(len, x_list)) ]
# compute loss if the target is given
@@ -1573,6 +1606,8 @@ class Base(nn.Module):
# other
attentions=None,
):
batch_size = len( logits )
if min_temperature < 0:
min_temperature = temperature
@@ -1598,6 +1633,14 @@ class Base(nn.Module):
if res:
return Sampled([ r[0] for r in res ], scores, [ r[1] for r in res ])
"""
elif quant_levels is None:
seq_lens = [ logit.shape[0] for logit in logits ]
entropy = [ calculate_entropix_metrics(
logit[:seq_lens[batch], :], # ( seq_len, vocab )
#attentions[batch, :, :, :seq_lens[batch], :seq_lens[batch]], # (layer, heads, seq_len, seq_len )
) for batch, logit in enumerate(logits) ]
"""
# (NAR) return the entire generated response
# Parallel decoding relies on the last N tokens in the logits, because each token predicts the token at the same position in the next RVQ level
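A hedged reading of the comment above: for a response of length N, the trailing N logit positions each predict the same position's token at the next RVQ level, so they can be read off in parallel (shapes assumed):

import torch

def nar_parallel_decode(logits: torch.Tensor, resp_len: int) -> torch.Tensor:
    # logits: (seq_len, vocab); the last resp_len positions carry the
    # predictions for the next RVQ level, one token per position
    return logits[-resp_len:, :].argmax(dim=-1)  # (resp_len,)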
@@ -1666,4 +1709,10 @@ class Base(nn.Module):
else:
res = [ Categorical(logits=logit).sample() for logit in logits ]
# calculate token probabilities
scores = [
[ F.softmax(logit[-1, :], dim=0)[token].item() for token in tokens ]
for logit, tokens in zip(logits, res)
]
return Sampled(res, scores, entropy)

View File

@@ -2,6 +2,7 @@
import argparse
import json
import time
import re
from pathlib import Path
@@ -93,7 +94,7 @@ def plot(paths, args):
#bbox_to_anchor=(1.04, 0.5),
)
def plot_sample_metrics( metrics ):
def plot_sample_metrics( metrics, filename=None ):
"""
fig = plt.figure()
fig.set_figwidth( 16 * len(metrics) // cfg.dataset.frames_per_second )
@@ -111,7 +112,11 @@ def plot_sample_metrics( metrics ):
#bbox_to_anchor=(1.04, 0.5),
)
out_path = cfg.rel_path / "metrics.png"
if not filename:
filename = f'{time.time()}.png'
out_path = cfg.rel_path / "metrics" / filename
out_path.parent.mkdir(parents=True, exist_ok=True)
plt.savefig(out_path, bbox_inches="tight")
if __name__ == "__main__":

View File

@@ -426,6 +426,7 @@ def sample_entropix(
top_p=1.0,
min_p=0.0,
cfg=EntropixSamplerConfig(),
metrics_only=False,
):
"""
temperature = cfg.temp

View File

@@ -193,7 +193,7 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
parser.add_argument("--dry-allowed-length", type=int, default=kwargs["dry-allowed-length"])
parser.add_argument("--entropix-sampling", action="store_true")
parser.add_argument("--layer-skip", action="store_true")
parser.add_argument("--layer-skip-exit-layer", type=int, default=kwargs["layer-skip-exit-layer"])
parser.add_argument("--layer-skip-exit-layer", type=int, default=kwargs["layer-skip-exit-layer"] if cfg.experimental else -1)
args, unknown = parser.parse_known_args()
tmp = tempfile.NamedTemporaryFile(suffix='.wav')
@@ -384,7 +384,7 @@ with ui:
layout["inference_tts"]["inputs"]["ar-temp"] = gr.Slider(value=0.5, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR. (0 to greedy* sample)")
layout["inference_tts"]["inputs"]["nar-temp"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (NAR)", info="Modifies the randomness from the samples in the NAR. (0 to greedy sample)")
with gr.Row():
layout["inference_tts"]["inputs"]["dynamic-sampling"] = gr.Checkbox(label="Dynamic Temperature", info="Dynamically adjusts the temperature based on the highest confident predicted token per sampling step.")
layout["inference_tts"]["inputs"]["layer-skip"] = gr.Checkbox(label="Layer Skip", info="Performs self-speculative early exit 'sampling'")
layout["inference_tts"]["inputs"]["language"] = gr.Dropdown(choices=get_languages(), label="Language", value="en")
with gr.Tab("Sampler Settings"):
with gr.Row():
@@ -411,9 +411,10 @@
with gr.Row():
layout["inference_tts"]["inputs"]["prefix-silence"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.0, step=0.05, label="Silence Prefix Duration", info="Amount of silence to prefix to the output response before beginning inference.")
with gr.Row():
layout["inference_tts"]["inputs"]["dynamic-sampling"] = gr.Checkbox(label="Dynamic Temperature", info="Dynamically adjusts the temperature based on the highest confident predicted token per sampling step.")
layout["inference_tts"]["inputs"]["entropix-sampling"] = gr.Checkbox(label="Entropix Sampling", info="Dynamically samples based on entropy/varentropy values from the logits / attention scores.")
with gr.Row():
layout["inference_tts"]["inputs"]["layer-skip-exit-layer"] = gr.Slider(value=11, minimum=0, maximum=11, step=1, label="Layer Skip Exit Layer", info="Model layer to exit early from.")
layout["inference_tts"]["inputs"]["layer-skip-exit-layer"] = gr.Slider(value=11, minimum=0, maximum=11, step=1, label="Layer Skip Exit Layer", info="Maximum model layer to exit early from.")
layout["inference_tts"]["buttons"]["inference"].click(