very, very naive layerskip speculative sampling (it just checks if the current layer's state is good enough)

mrq 2024-11-02 11:49:05 -05:00
parent 62fe5b0943
commit ded746e157
6 changed files with 110 additions and 43 deletions
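
The idea, roughly: during AR decoding, after each decoder layer the running hidden state is pushed through the final norm and output head, and if the resulting next-token distribution already looks confident enough (low entropy and low varentropy, entropix-style), the remaining layers are skipped and sampling proceeds from that layer's projection. A minimal sketch of that check, assuming illustrative names (should_exit_early, norm, classifier) and reusing the 0.1 thresholds and n_layers // 2 minimum layer that appear later in this diff:

    import torch.nn.functional as F

    def should_exit_early(hidden_state, norm, classifier, layer, n_layers,
                          entropy_threshold=0.1, varentropy_threshold=0.1):
        # don't bother with the first half of the stack
        if layer < n_layers // 2:
            return False
        # project the unnormalized hidden state the same way the final output head would
        logits = classifier(norm(hidden_state))            # ( batch, seq_len, vocab )
        log_probs = F.log_softmax(logits[:, -1, :], dim=-1)
        probs = log_probs.exp()
        # entropy / varentropy of the next-token distribution
        entropy = -(probs * log_probs).sum(dim=-1)
        varentropy = (probs * (log_probs + entropy.unsqueeze(-1)) ** 2).sum(dim=-1)
        # "good enough": exit early and sample from this layer's projection
        return bool((entropy < entropy_threshold).all() and (varentropy < varentropy_threshold).all())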

View File

@@ -14,10 +14,10 @@ from torch.nn.utils.rnn import pad_sequence
 import random
 import math
+import time
 from einops import rearrange
 from torch import Tensor
 from tqdm import trange
-from time import perf_counter
 import logging
@@ -66,6 +66,7 @@ class AR_NAR(Base):
 sampling_dry_base=1.75,
 sampling_dry_allowed_length=2,
 sampling_entropix=False,
 sampling_layer_skip: bool = False,
 sampling_layer_skip_exit_layer: int = -1,
@@ -281,6 +282,11 @@ class AR_NAR(Base):
 original_sampling_repetition_penalty_decay = sampling_repetition_penalty_decay
 """

+sampling_layer_skip_variables = {} if sampling_layer_skip else None
+if sampling_layer_skip:
+    sampling_layer_skip_variables["max_layer"] = sampling_layer_skip_exit_layer if sampling_layer_skip_exit_layer >= 0 else self.n_layers
+
 for i, sequence in enumerate( sequence_list ):
     # add <bos> to text for STT
     if task_list[i] in text_task:
@@ -329,7 +335,7 @@
     inputs=inputs,
     state=state,
-    layer_skip_exit_layer=sampling_layer_skip_exit_layer,
+    layer_skip_variables=sampling_layer_skip_variables,
     output_attentions=sampling_entropix,
 )
@@ -360,15 +366,11 @@
 r = sampled[0]

-if sampled.entropy:
-    metrics.append( sampled.entropy )
-"""
-elif sampled.confidence:
-    metrics.append( sampled.confidence )
-"""
-elif False:
-    p = [ { "p": torch.nn.functional.softmax(logit[-1, :].cpu(), dim=0)[token.item()].item() } for logit, token in zip(logits, r) ]
-    metrics.append( p )
+if cfg.experimental:
+    if sampled.entropy:
+        metrics.append( sampled.entropy )
+    elif sampled.scores:
+        metrics.append( [ { "p": p[0] } for p in sampled.scores ] )

 if mirostat is not None:
     mirostat = sampled.scores
@@ -402,7 +404,13 @@
 if metrics:
     from ..plot import plot_sample_metrics
-    plot_sample_metrics( metrics )
+    filename = "metrics"
+    if sampling_entropix:
+        filename += f'[entropix]'
+    if sampling_layer_skip_exit_layer >= 0:
+        filename += f'[{sampling_layer_skip_exit_layer+1}]'
+    plot_sample_metrics( metrics, filename=f'{filename}.png' )

 # pick the best scoring candidate
 # desu this is always going to be candidate 0
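
For reference, the sampling path above only ever populates "max_layer"; the other knobs fall back to the defaults baked into the layer_skip_lambda defined further down in base.py. The dict handed down as layer_skip_variables looks something like this (the commented keys are the assumed defaults, not values set here):

    sampling_layer_skip_variables = {
        "max_layer": sampling_layer_skip_exit_layer if sampling_layer_skip_exit_layer >= 0 else self.n_layers,
        # "min_layer": self.n_layers // 2,   # default: skip the check for the first half of the stack
        # "logits_entropy": 0.1,             # default threshold
        # "logits_varentropy": 0.1,          # default threshold
    }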

View File

@@ -358,7 +358,8 @@ class LlamaModel_Adapted(LlamaModel):
 output_hidden_states: Optional[bool] = None,
 return_dict: Optional[bool] = None,
 cache_position: Optional[torch.LongTensor] = None,
-exit_layer: Optional[int] = -1,
+layer_skip_lambda = None,
 ) -> Union[Tuple, BaseModelOutputWithPast]:
 output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
 output_hidden_states = (
@@ -451,7 +452,9 @@ class LlamaModel_Adapted(LlamaModel):
     if output_attentions:
         all_self_attns += (layer_outputs[1],)

-    if 0 <= exit_layer and exit_layer <= l:
+    # check if we should early-exit
+    if layer_skip_lambda and layer_skip_lambda( l, hidden_states ):
+        #_logger.info(f"Early exit at layer: {l}")
         break

 hidden_states = self.norm(hidden_states)
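
For context, it is the surrounding per-layer decode loop (not shown in the hunk) that invokes the lambda; paraphrased, and assuming the usual transformers LlamaModel layer loop, it looks roughly like:

    for l, decoder_layer in enumerate(self.layers):
        layer_outputs = decoder_layer(
            hidden_states,
            attention_mask=causal_mask,
            position_ids=position_ids,
            past_key_value=past_key_values,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
        )
        hidden_states = layer_outputs[0]

        # the lambda sees the layer index and the *unnormalized* running hidden states;
        # returning True stops the stack early
        if layer_skip_lambda and layer_skip_lambda( l, hidden_states ):
            break

    # whatever state we stopped at still goes through the final norm
    hidden_states = self.norm(hidden_states)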

View File

@@ -38,8 +38,9 @@ from ..emb.qnt import encode_as_embedding
 # yuck, kind of needed
 from ..data import get_task_symmap

+# these seem more elegant than a dict
 Logits = namedtuple('Logits', ['logits', 'state', 'aux_loss', 'attentions', 'hidden_states'])
-Sampled = namedtuple('Sampled', ['out', 'scores', 'entropy']) # these seem more elegant than a dict
+Sampled = namedtuple('Sampled', ['out', 'scores', 'entropy'])
 LossStats = namedtuple('LossStats', ['loss', 'stats'])

 """
@@ -476,6 +477,7 @@ class Base(nn.Module):
 self.unified_position_ids = unified_position_ids
 self.interleave = interleave
 self.layerskip = layerskip
+self.special_tasks = [ "len", "stt" ]

 self.text_emb = Embedding(n_text_tokens, d_model)
 self.langs_emb = None
@@ -827,7 +829,7 @@ class Base(nn.Module):
 state = None,
-layer_skip_exit_layer = -1,
+layer_skip_lambda = None,
 output_attentions = False,
 output_hidden_states = False,
@@ -846,7 +848,7 @@ class Base(nn.Module):
 inputs_embeds=x,
 past_key_values=state,
 position_ids=position_ids,
-use_cache=not self.training,
+use_cache=False, # not self.training,
 output_attentions=output_attentions,
 output_hidden_states=output_hidden_states,
 return_dict=True,
@@ -855,8 +857,8 @@ class Base(nn.Module):
 if self.n_experts > 1 and self.training:
     kwargs["output_router_logits"] = True

-if self.layerskip and 0 <= layer_skip_exit_layer and layer_skip_exit_layer < self.n_layers:
-    kwargs["exit_layer"] = layer_skip_exit_layer
+if self.layerskip and layer_skip_lambda is not None:
+    kwargs["layer_skip_lambda"] = layer_skip_lambda

 output = self.model(**kwargs)
 x = output["last_hidden_state"]
@@ -938,14 +940,6 @@ class Base(nn.Module):
     # but skip the last state, as it already is normalized
     hidden_states = [ x if i == self.n_layers - 1 else self.model.norm(output.hidden_states[i]) for i, state in enumerate( hidden_states ) ]

-# output projection layer with masking
-if self.classifier is not None:
-    x = self.classifier(x) * mask
-
-    if output.hidden_states:
-        for i, state in enumerate( hidden_states ):
-            hidden_states[i] = self.classifier(hidden_states[i]) * m
-
 return Logits(x, state, aux_loss, attentions, hidden_states)

 # takes a bunch of separate lists and parses them into an ordered array of tuples to guide input sequence creation
@@ -965,8 +959,6 @@ class Base(nn.Module):
 device = text_list[0].device
 batch_size = len(text_list)

-special_tasks = ["stt", "len"]
 inputs = [ [] for _ in range(batch_size) ]
 for i in range(batch_size):
     quant_level = quant_levels[i] if quant_levels is not None else 0
@@ -981,7 +973,7 @@ class Base(nn.Module):
 # Base-line TTS task
 # Sequence: <text><sep><rvq lvl><sep><prom><sep><resp>
 # prom /may/ include <task> tokens inside to help guide things, per SpeechX
-if f'<{task_type}>' in get_task_symmap() and task_type not in special_tasks:
+if f'<{task_type}>' in get_task_symmap() and task_type not in self.special_tasks:
     # insert the text prompt
     if text_list is not None and text_list[i] is not None:
         inputs[i].append( ( "text", text_list[i] ) )
@@ -1259,9 +1251,8 @@ class Base(nn.Module):
 stats = dict(acc = dict())

 device = logits[0].device
-special_tasks = [ "len", "stt" ]
 summed_embeddings_task = [ "stt" ]
-classifier_quant_levels = quant_levels if self.classifier is not None else [ -1 if inputs[i][0][-1] in special_tasks else l for i, l in enumerate( quant_levels ) ]
+classifier_quant_levels = quant_levels if self.classifier is not None else [ -1 if inputs[i][0][-1] in self.special_tasks else l for i, l in enumerate( quant_levels ) ]

 # handles tasks where the prompt has task tokens injected in the middle
 def prompt_input_to_token( input, quant_level ):
@@ -1443,11 +1434,46 @@ class Base(nn.Module):
     quant_levels: int | list[int] | Tensor | None = None,
     state: dict | list | None = None,
-    layer_skip_exit_layer: int = -1,
+    layer_skip_variables: dict | None = None,
     output_attentions: bool = False,
     output_hidden_states: bool = False,
 ):
+    # return early if it's "good" enough"
+    # lambda because we need to capture the classifier_quant_levels and mask
+    def layer_skip_lambda( layer, logits ):
+        kwargs = {
+            "logits_entropy": 0.1,
+            "logits_varentropy": 0.1,
+            "min_layer": self.n_layers // 2,
+            "max_layer": self.n_layers,
+        }
+
+        kwargs.update( layer_skip_variables )
+
+        # don't bother on early layers
+        if layer < kwargs["min_layer"]:
+            return False
+        # bail if we want to force early layers
+        if kwargs["max_layer"] < layer:
+            return True
+
+        # hidden states aren't normalized
+        x = self.model.norm( logits )
+
+        # output projection layer with masking
+        if self.classifier is not None:
+            x = self.classifier(x) * m
+        elif self.classifiers is not None:
+            logits = self.classifiers(logits, levels = classifier_quant_levels) * m
+
+        # calculate metrics
+        metrics = calculate_entropix_metrics( logits )
+        # exit early if "good enough""
+        return metrics["logits_entropy"] < kwargs["logits_entropy"] and metrics["logits_varentropy"] < kwargs["logits_varentropy"]
+
     x_list = self.inputs_to_embeddings( inputs, quant_levels )
     x, m = list_to_tensor(x_list)
@@ -1459,7 +1485,8 @@ class Base(nn.Module):
 if quant_levels is None:
     quant_levels = [ 0 for _ in range(batch_size) ]

-if self.layerskip:
+# we only need hidden states if we're training with layerskip
+if self.layerskip and training:
     output_hidden_states = True

 # pad our input and mask, but retain the original length by doing it after
@@ -1478,6 +1505,8 @@ class Base(nn.Module):
 # needs to be done here as we still have our raw inputs
 position_ids = self.inputs_to_position_ids( inputs, mask=m.squeeze(-1).int() ) if not self.unified_position_ids else None

+classifier_quant_levels = [ -1 if inputs[i][0][-1] in self.special_tasks else l for i, l in enumerate( quant_levels ) ]
+
 output = self._forward(
     inputs=x,
@@ -1486,17 +1515,22 @@ class Base(nn.Module):
     position_ids=position_ids,
     output_attentions = output_attentions,
     output_hidden_states = output_hidden_states,
-    layer_skip_exit_layer = layer_skip_exit_layer,
+    layer_skip_lambda = layer_skip_lambda if self.layerskip and layer_skip_variables else None,
 )

 logits = output.logits
 hidden_states = output.hidden_states

+# output projection layer with masking
+if self.classifier is not None:
+    logits = self.classifier(logits) * m
+
+    if output.hidden_states:
+        for i, state in enumerate( hidden_states ):
+            hidden_states[i] = self.classifier(hidden_states[i]) * m
+
 # to-do: piece-wise classification, now that there's a head for text
 # although again, one single monolithic head would be preferable instead......
 if self.classifiers is not None:
-    special_tasks = [ "len", "stt" ]
-    classifier_quant_levels = [ -1 if inputs[i][0][-1] in special_tasks else l for i, l in enumerate( quant_levels ) ]
     logits = self.classifiers(logits, levels = classifier_quant_levels) * m

 if hidden_states is not None:
@@ -1508,7 +1542,6 @@ class Base(nn.Module):
 if hidden_states is not None:
     for i, state in enumerate( hidden_states ):
-        # remove padding
         hidden_states[i] = [ hi[:li] for hi, li in zip(hidden_states[i], map(len, x_list)) ]

 # compute loss if the target is given
@@ -1573,6 +1606,8 @@ class Base(nn.Module):
     # other
     attentions=None,
 ):
+    batch_size = len( logits )
+
     if min_temperature < 0:
         min_temperature = temperature
@@ -1598,6 +1633,14 @@ class Base(nn.Module):
 if res:
     return Sampled([ r[0] for r in res ], scores, [ r[1] for r in res ])
+"""
+elif quant_levels is None:
+    seq_lens = [ logit.shape[0] for logit in logits ]
+    entropy = [ calculate_entropix_metrics(
+        logit[:seq_lens[batch], :], # ( seq_len, vocab )
+        #attentions[batch, :, :, :seq_lens[batch], :seq_lens[batch]], # (layer, heads, seq_len, seq_len )
+    ) for batch, logit in enumerate(logits) ]
+"""

 # (NAR) return the entire generated response
 # Parallel decoding relies on the last N tokens in the logits, because each token predicts the next RVQ layer in the same place (forgetfully obviously)
@@ -1666,4 +1709,10 @@ class Base(nn.Module):
 else:
     res = [ Categorical(logits=logit).sample() for logit in logits ]

+# calculate token probabilities
+scores = [
+    [ F.softmax(logit[-1, :], dim=0)[token].item() for token in tokens ]
+    for logit, tokens in zip(logits, res)
+]
+
 return Sampled(res, scores, entropy)

View File

@@ -2,6 +2,7 @@
 import argparse
 import json
+import time
 import re

 from pathlib import Path
@@ -93,7 +94,7 @@ def plot(paths, args):
     #bbox_to_anchor=(1.04, 0.5),
 )

-def plot_sample_metrics( metrics ):
+def plot_sample_metrics( metrics, filename=None ):
     """
     fig = plt.figure()
     fig.set_figwidth( 16 * len(metrics) // cfg.dataset.frames_per_second )
@@ -111,7 +112,11 @@ def plot_sample_metrics( metrics ):
     #bbox_to_anchor=(1.04, 0.5),
 )

-out_path = cfg.rel_path / "metrics.png"
+if not filename:
+    filename = f'{time.time()}.png'
+
+out_path = cfg.rel_path / "metrics" / filename
+out_path.parent.mkdir(parents=True, exist_ok=True)
 plt.savefig(out_path, bbox_inches="tight")

 if __name__ == "__main__":

View File

@@ -426,6 +426,7 @@ def sample_entropix(
     top_p=1.0,
     min_p=0.0,
     cfg=EntropixSamplerConfig(),
+    metrics_only=False,
 ):
     """
     temperature = cfg.temp

View File

@@ -193,7 +193,7 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
 parser.add_argument("--dry-allowed-length", type=int, default=kwargs["dry-allowed-length"])
 parser.add_argument("--entropix-sampling", action="store_true")
 parser.add_argument("--layer-skip", action="store_true")
-parser.add_argument("--layer-skip-exit-layer", type=int, default=kwargs["layer-skip-exit-layer"])
+parser.add_argument("--layer-skip-exit-layer", type=int, default=kwargs["layer-skip-exit-layer"] if cfg.experimental else -1)
 args, unknown = parser.parse_known_args()

 tmp = tempfile.NamedTemporaryFile(suffix='.wav')
@@ -384,7 +384,7 @@ with ui:
         layout["inference_tts"]["inputs"]["ar-temp"] = gr.Slider(value=0.5, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR. (0 to greedy* sample)")
         layout["inference_tts"]["inputs"]["nar-temp"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (NAR)", info="Modifies the randomness from the samples in the NAR. (0 to greedy sample)")
     with gr.Row():
-        layout["inference_tts"]["inputs"]["dynamic-sampling"] = gr.Checkbox(label="Dynamic Temperature", info="Dynamically adjusts the temperature based on the highest confident predicted token per sampling step.")
+        layout["inference_tts"]["inputs"]["layer-skip"] = gr.Checkbox(label="Layer Skip", info="Performs self-speculative early exit 'sampling'")
         layout["inference_tts"]["inputs"]["language"] = gr.Dropdown(choices=get_languages(), label="Language", value="en")
 with gr.Tab("Sampler Settings"):
     with gr.Row():
@@ -411,9 +411,10 @@ with ui:
 with gr.Row():
     layout["inference_tts"]["inputs"]["prefix-silence"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.0, step=0.05, label="Silence Prefix Duration", info="Amount of silence to prefix to the output response before beginning inference.")
 with gr.Row():
+    layout["inference_tts"]["inputs"]["dynamic-sampling"] = gr.Checkbox(label="Dynamic Temperature", info="Dynamically adjusts the temperature based on the highest confident predicted token per sampling step.")
     layout["inference_tts"]["inputs"]["entropix-sampling"] = gr.Checkbox(label="Entropix Sampling", info="Dynamically samples based on entropy/varentropy values from the logits / attention scores.")
 with gr.Row():
-    layout["inference_tts"]["inputs"]["layer-skip-exit-layer"] = gr.Slider(value=11, minimum=0, maximum=11, step=1, label="Layer Skip Exit Layer", info="Model layer to exit early from.")
+    layout["inference_tts"]["inputs"]["layer-skip-exit-layer"] = gr.Slider(value=11, minimum=0, maximum=11, step=1, label="Layer Skip Exit Layer", info="Maximum model layer to exit early from.")

 layout["inference_tts"]["buttons"]["inference"].click(