very, very naive layerskip speculative sampling (it just checks if the current layer's state is good enough)
This commit is contained in:
parent
62fe5b0943
commit
ded746e157
|
@ -14,10 +14,10 @@ from torch.nn.utils.rnn import pad_sequence
|
|||
|
||||
import random
|
||||
import math
|
||||
import time
|
||||
from einops import rearrange
|
||||
from torch import Tensor
|
||||
from tqdm import trange
|
||||
from time import perf_counter
|
||||
|
||||
import logging
|
||||
|
||||
|
@ -66,6 +66,7 @@ class AR_NAR(Base):
|
|||
sampling_dry_base=1.75,
|
||||
sampling_dry_allowed_length=2,
|
||||
sampling_entropix=False,
|
||||
|
||||
sampling_layer_skip: bool = False,
|
||||
sampling_layer_skip_exit_layer: int = -1,
|
||||
|
||||
|
@ -281,6 +282,11 @@ class AR_NAR(Base):
|
|||
original_sampling_repetition_penalty_decay = sampling_repetition_penalty_decay
|
||||
"""
|
||||
|
||||
sampling_layer_skip_variables = {} if sampling_layer_skip else None
|
||||
|
||||
if sampling_layer_skip:
|
||||
sampling_layer_skip_variables["max_layer"] = sampling_layer_skip_exit_layer if sampling_layer_skip_exit_layer >= 0 else self.n_layers
|
||||
|
||||
for i, sequence in enumerate( sequence_list ):
|
||||
# add <bos> to text for STT
|
||||
if task_list[i] in text_task:
|
||||
|
@ -329,7 +335,7 @@ class AR_NAR(Base):
|
|||
inputs=inputs,
|
||||
state=state,
|
||||
|
||||
layer_skip_exit_layer=sampling_layer_skip_exit_layer,
|
||||
layer_skip_variables=sampling_layer_skip_variables,
|
||||
|
||||
output_attentions=sampling_entropix,
|
||||
)
|
||||
|
@ -360,15 +366,11 @@ class AR_NAR(Base):
|
|||
|
||||
r = sampled[0]
|
||||
|
||||
if sampled.entropy:
|
||||
metrics.append( sampled.entropy )
|
||||
"""
|
||||
elif sampled.confidence:
|
||||
metrics.append( sampled.confidence )
|
||||
"""
|
||||
elif False:
|
||||
p = [ { "p": torch.nn.functional.softmax(logit[-1, :].cpu(), dim=0)[token.item()].item() } for logit, token in zip(logits, r) ]
|
||||
metrics.append( p )
|
||||
if cfg.experimental:
|
||||
if sampled.entropy:
|
||||
metrics.append( sampled.entropy )
|
||||
elif sampled.scores:
|
||||
metrics.append( [ { "p": p[0] } for p in sampled.scores ] )
|
||||
|
||||
if mirostat is not None:
|
||||
mirostat = sampled.scores
|
||||
|
@ -402,7 +404,13 @@ class AR_NAR(Base):
|
|||
|
||||
if metrics:
|
||||
from ..plot import plot_sample_metrics
|
||||
plot_sample_metrics( metrics )
|
||||
filename = "metrics"
|
||||
if sampling_entropix:
|
||||
filename += f'[entropix]'
|
||||
if sampling_layer_skip_exit_layer >= 0:
|
||||
filename += f'[{sampling_layer_skip_exit_layer+1}]'
|
||||
|
||||
plot_sample_metrics( metrics, filename=f'{filename}.png' )
|
||||
|
||||
# pick the best scoring candidate
|
||||
# desu this is always going to be candidate 0
|
||||
|
|
|
@ -358,7 +358,8 @@ class LlamaModel_Adapted(LlamaModel):
|
|||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
exit_layer: Optional[int] = -1,
|
||||
|
||||
layer_skip_lambda = None,
|
||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
|
@ -451,7 +452,9 @@ class LlamaModel_Adapted(LlamaModel):
|
|||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
|
||||
if 0 <= exit_layer and exit_layer <= l:
|
||||
# check if we should early-exit
|
||||
if layer_skip_lambda and layer_skip_lambda( l, hidden_states ):
|
||||
#_logger.info(f"Early exit at layer: {l}")
|
||||
break
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
|
|
@ -38,8 +38,9 @@ from ..emb.qnt import encode_as_embedding
|
|||
# yuck, kind of needed
|
||||
from ..data import get_task_symmap
|
||||
|
||||
# these seem more elegant than a dict
|
||||
Logits = namedtuple('Logits', ['logits', 'state', 'aux_loss', 'attentions', 'hidden_states'])
|
||||
Sampled = namedtuple('Sampled', ['out', 'scores', 'entropy']) # these seem more elegant than a dict
|
||||
Sampled = namedtuple('Sampled', ['out', 'scores', 'entropy'])
|
||||
LossStats = namedtuple('LossStats', ['loss', 'stats'])
|
||||
|
||||
"""
|
||||
|
@ -476,6 +477,7 @@ class Base(nn.Module):
|
|||
self.unified_position_ids = unified_position_ids
|
||||
self.interleave = interleave
|
||||
self.layerskip = layerskip
|
||||
self.special_tasks = [ "len", "stt" ]
|
||||
|
||||
self.text_emb = Embedding(n_text_tokens, d_model)
|
||||
self.langs_emb = None
|
||||
|
@ -827,7 +829,7 @@ class Base(nn.Module):
|
|||
|
||||
state = None,
|
||||
|
||||
layer_skip_exit_layer = -1,
|
||||
layer_skip_lambda = None,
|
||||
|
||||
output_attentions = False,
|
||||
output_hidden_states = False,
|
||||
|
@ -846,7 +848,7 @@ class Base(nn.Module):
|
|||
inputs_embeds=x,
|
||||
past_key_values=state,
|
||||
position_ids=position_ids,
|
||||
use_cache=not self.training,
|
||||
use_cache=False, # not self.training,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=True,
|
||||
|
@ -855,8 +857,8 @@ class Base(nn.Module):
|
|||
if self.n_experts > 1 and self.training:
|
||||
kwargs["output_router_logits"] = True
|
||||
|
||||
if self.layerskip and 0 <= layer_skip_exit_layer and layer_skip_exit_layer < self.n_layers:
|
||||
kwargs["exit_layer"] = layer_skip_exit_layer
|
||||
if self.layerskip and layer_skip_lambda is not None:
|
||||
kwargs["layer_skip_lambda"] = layer_skip_lambda
|
||||
|
||||
output = self.model(**kwargs)
|
||||
x = output["last_hidden_state"]
|
||||
|
@ -938,14 +940,6 @@ class Base(nn.Module):
|
|||
# but skip the last state, as it already is normalized
|
||||
hidden_states = [ x if i == self.n_layers - 1 else self.model.norm(output.hidden_states[i]) for i, state in enumerate( hidden_states ) ]
|
||||
|
||||
# output projection layer with masking
|
||||
if self.classifier is not None:
|
||||
x = self.classifier(x) * mask
|
||||
|
||||
if output.hidden_states:
|
||||
for i, state in enumerate( hidden_states ):
|
||||
hidden_states[i] = self.classifier(hidden_states[i]) * m
|
||||
|
||||
return Logits(x, state, aux_loss, attentions, hidden_states)
|
||||
|
||||
# takes a bunch of separate lists and parses them into an ordered array of tuples to guide input sequence creation
|
||||
|
@ -965,8 +959,6 @@ class Base(nn.Module):
|
|||
device = text_list[0].device
|
||||
batch_size = len(text_list)
|
||||
|
||||
special_tasks = ["stt", "len"]
|
||||
|
||||
inputs = [ [] for _ in range(batch_size) ]
|
||||
for i in range(batch_size):
|
||||
quant_level = quant_levels[i] if quant_levels is not None else 0
|
||||
|
@ -981,7 +973,7 @@ class Base(nn.Module):
|
|||
# Base-line TTS task
|
||||
# Sequence: <text><sep><rvq lvl><sep><prom><sep><resp>
|
||||
# prom /may/ include <task> tokens inside to help guide things, per SpeechX
|
||||
if f'<{task_type}>' in get_task_symmap() and task_type not in special_tasks:
|
||||
if f'<{task_type}>' in get_task_symmap() and task_type not in self.special_tasks:
|
||||
# insert the text prompt
|
||||
if text_list is not None and text_list[i] is not None:
|
||||
inputs[i].append( ( "text", text_list[i] ) )
|
||||
|
@ -1259,9 +1251,8 @@ class Base(nn.Module):
|
|||
stats = dict(acc = dict())
|
||||
|
||||
device = logits[0].device
|
||||
special_tasks = [ "len", "stt" ]
|
||||
summed_embeddings_task = [ "stt" ]
|
||||
classifier_quant_levels = quant_levels if self.classifier is not None else [ -1 if inputs[i][0][-1] in special_tasks else l for i, l in enumerate( quant_levels ) ]
|
||||
classifier_quant_levels = quant_levels if self.classifier is not None else [ -1 if inputs[i][0][-1] in self.special_tasks else l for i, l in enumerate( quant_levels ) ]
|
||||
|
||||
# handles tasks where the prompt has task tokens injected in the middle
|
||||
def prompt_input_to_token( input, quant_level ):
|
||||
|
@ -1443,11 +1434,46 @@ class Base(nn.Module):
|
|||
|
||||
quant_levels: int | list[int] | Tensor | None = None,
|
||||
state: dict | list | None = None,
|
||||
layer_skip_exit_layer: int = -1,
|
||||
|
||||
layer_skip_variables: dict | None = None,
|
||||
|
||||
output_attentions: bool = False,
|
||||
output_hidden_states: bool = False,
|
||||
):
|
||||
# return early if it's "good" enough"
|
||||
# lambda because we need to capture the classifier_quant_levels and mask
|
||||
def layer_skip_lambda( layer, logits ):
|
||||
kwargs = {
|
||||
"logits_entropy": 0.1,
|
||||
"logits_varentropy": 0.1,
|
||||
"min_layer": self.n_layers // 2,
|
||||
"max_layer": self.n_layers,
|
||||
}
|
||||
|
||||
kwargs.update( layer_skip_variables )
|
||||
|
||||
# don't bother on early layers
|
||||
if layer < kwargs["min_layer"]:
|
||||
return False
|
||||
# bail if we want to force early layers
|
||||
if kwargs["max_layer"] < layer:
|
||||
return True
|
||||
|
||||
# hidden states aren't normalized
|
||||
x = self.model.norm( logits )
|
||||
|
||||
# output projection layer with masking
|
||||
if self.classifier is not None:
|
||||
x = self.classifier(x) * m
|
||||
elif self.classifiers is not None:
|
||||
logits = self.classifiers(logits, levels = classifier_quant_levels) * m
|
||||
|
||||
# calculate metrics
|
||||
metrics = calculate_entropix_metrics( logits )
|
||||
|
||||
# exit early if "good enough""
|
||||
return metrics["logits_entropy"] < kwargs["logits_entropy"] and metrics["logits_varentropy"] < kwargs["logits_varentropy"]
|
||||
|
||||
x_list = self.inputs_to_embeddings( inputs, quant_levels )
|
||||
x, m = list_to_tensor(x_list)
|
||||
|
||||
|
@ -1459,7 +1485,8 @@ class Base(nn.Module):
|
|||
if quant_levels is None:
|
||||
quant_levels = [ 0 for _ in range(batch_size) ]
|
||||
|
||||
if self.layerskip:
|
||||
# we only need hidden states if we're training with layerskip
|
||||
if self.layerskip and training:
|
||||
output_hidden_states = True
|
||||
|
||||
# pad our input and mask, but retain the original length by doing it after
|
||||
|
@ -1478,6 +1505,8 @@ class Base(nn.Module):
|
|||
|
||||
# needs to be done here as we still have our raw inputs
|
||||
position_ids = self.inputs_to_position_ids( inputs, mask=m.squeeze(-1).int() ) if not self.unified_position_ids else None
|
||||
|
||||
classifier_quant_levels = [ -1 if inputs[i][0][-1] in self.special_tasks else l for i, l in enumerate( quant_levels ) ]
|
||||
|
||||
output = self._forward(
|
||||
inputs=x,
|
||||
|
@ -1486,17 +1515,22 @@ class Base(nn.Module):
|
|||
position_ids=position_ids,
|
||||
output_attentions = output_attentions,
|
||||
output_hidden_states = output_hidden_states,
|
||||
layer_skip_exit_layer = layer_skip_exit_layer,
|
||||
layer_skip_lambda = layer_skip_lambda if self.layerskip and layer_skip_variables else None,
|
||||
)
|
||||
|
||||
logits = output.logits
|
||||
hidden_states = output.hidden_states
|
||||
|
||||
# output projection layer with masking
|
||||
if self.classifier is not None:
|
||||
logits = self.classifier(logits) * m
|
||||
|
||||
if output.hidden_states:
|
||||
for i, state in enumerate( hidden_states ):
|
||||
hidden_states[i] = self.classifier(hidden_states[i]) * m
|
||||
# to-do: piece-wise classification, now that there's a head for text
|
||||
# although again, one single monolithic head would be preferable instead......
|
||||
if self.classifiers is not None:
|
||||
special_tasks = [ "len", "stt" ]
|
||||
classifier_quant_levels = [ -1 if inputs[i][0][-1] in special_tasks else l for i, l in enumerate( quant_levels ) ]
|
||||
logits = self.classifiers(logits, levels = classifier_quant_levels) * m
|
||||
|
||||
if hidden_states is not None:
|
||||
|
@ -1508,7 +1542,6 @@ class Base(nn.Module):
|
|||
|
||||
if hidden_states is not None:
|
||||
for i, state in enumerate( hidden_states ):
|
||||
# remove padding
|
||||
hidden_states[i] = [ hi[:li] for hi, li in zip(hidden_states[i], map(len, x_list)) ]
|
||||
|
||||
# compute loss if the target is given
|
||||
|
@ -1573,6 +1606,8 @@ class Base(nn.Module):
|
|||
# other
|
||||
attentions=None,
|
||||
):
|
||||
batch_size = len( logits )
|
||||
|
||||
if min_temperature < 0:
|
||||
min_temperature = temperature
|
||||
|
||||
|
@ -1598,6 +1633,14 @@ class Base(nn.Module):
|
|||
|
||||
if res:
|
||||
return Sampled([ r[0] for r in res ], scores, [ r[1] for r in res ])
|
||||
"""
|
||||
elif quant_levels is None:
|
||||
seq_lens = [ logit.shape[0] for logit in logits ]
|
||||
entropy = [ calculate_entropix_metrics(
|
||||
logit[:seq_lens[batch], :], # ( seq_len, vocab )
|
||||
#attentions[batch, :, :, :seq_lens[batch], :seq_lens[batch]], # (layer, heads, seq_len, seq_len )
|
||||
) for batch, logit in enumerate(logits) ]
|
||||
"""
|
||||
|
||||
# (NAR) return the entire generated response
|
||||
# Parallel decoding relies on the last N tokens in the logits, because each token predicts the next RVQ layer in the same place (forgetfully obviously)
|
||||
|
@ -1666,4 +1709,10 @@ class Base(nn.Module):
|
|||
else:
|
||||
res = [ Categorical(logits=logit).sample() for logit in logits ]
|
||||
|
||||
# calculate token probabilities
|
||||
scores = [
|
||||
[ F.softmax(logit[-1, :], dim=0)[token].item() for token in tokens ]
|
||||
for logit, tokens in zip(logits, res)
|
||||
]
|
||||
|
||||
return Sampled(res, scores, entropy)
|
|
@ -2,6 +2,7 @@
|
|||
|
||||
import argparse
|
||||
import json
|
||||
import time
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
|
@ -93,7 +94,7 @@ def plot(paths, args):
|
|||
#bbox_to_anchor=(1.04, 0.5),
|
||||
)
|
||||
|
||||
def plot_sample_metrics( metrics ):
|
||||
def plot_sample_metrics( metrics, filename=None ):
|
||||
"""
|
||||
fig = plt.figure()
|
||||
fig.set_figwidth( 16 * len(metrics) // cfg.dataset.frames_per_second )
|
||||
|
@ -111,7 +112,11 @@ def plot_sample_metrics( metrics ):
|
|||
#bbox_to_anchor=(1.04, 0.5),
|
||||
)
|
||||
|
||||
out_path = cfg.rel_path / "metrics.png"
|
||||
if not filename:
|
||||
filename = f'{time.time()}.png'
|
||||
|
||||
out_path = cfg.rel_path / "metrics" / filename
|
||||
out_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
plt.savefig(out_path, bbox_inches="tight")
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -426,6 +426,7 @@ def sample_entropix(
|
|||
top_p=1.0,
|
||||
min_p=0.0,
|
||||
cfg=EntropixSamplerConfig(),
|
||||
metrics_only=False,
|
||||
):
|
||||
"""
|
||||
temperature = cfg.temp
|
||||
|
|
|
@ -193,7 +193,7 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
|
|||
parser.add_argument("--dry-allowed-length", type=int, default=kwargs["dry-allowed-length"])
|
||||
parser.add_argument("--entropix-sampling", action="store_true")
|
||||
parser.add_argument("--layer-skip", action="store_true")
|
||||
parser.add_argument("--layer-skip-exit-layer", type=int, default=kwargs["layer-skip-exit-layer"])
|
||||
parser.add_argument("--layer-skip-exit-layer", type=int, default=kwargs["layer-skip-exit-layer"] if cfg.experimental else -1)
|
||||
args, unknown = parser.parse_known_args()
|
||||
|
||||
tmp = tempfile.NamedTemporaryFile(suffix='.wav')
|
||||
|
@ -384,7 +384,7 @@ with ui:
|
|||
layout["inference_tts"]["inputs"]["ar-temp"] = gr.Slider(value=0.5, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR. (0 to greedy* sample)")
|
||||
layout["inference_tts"]["inputs"]["nar-temp"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (NAR)", info="Modifies the randomness from the samples in the NAR. (0 to greedy sample)")
|
||||
with gr.Row():
|
||||
layout["inference_tts"]["inputs"]["dynamic-sampling"] = gr.Checkbox(label="Dynamic Temperature", info="Dynamically adjusts the temperature based on the highest confident predicted token per sampling step.")
|
||||
layout["inference_tts"]["inputs"]["layer-skip"] = gr.Checkbox(label="Layer Skip", info="Performs self-speculative early exit 'sampling'")
|
||||
layout["inference_tts"]["inputs"]["language"] = gr.Dropdown(choices=get_languages(), label="Language", value="en")
|
||||
with gr.Tab("Sampler Settings"):
|
||||
with gr.Row():
|
||||
|
@ -411,9 +411,10 @@ with ui:
|
|||
with gr.Row():
|
||||
layout["inference_tts"]["inputs"]["prefix-silence"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.0, step=0.05, label="Silence Prefix Duration", info="Amount of silence to prefix to the output response before beginning inference.")
|
||||
with gr.Row():
|
||||
layout["inference_tts"]["inputs"]["dynamic-sampling"] = gr.Checkbox(label="Dynamic Temperature", info="Dynamically adjusts the temperature based on the highest confident predicted token per sampling step.")
|
||||
layout["inference_tts"]["inputs"]["entropix-sampling"] = gr.Checkbox(label="Entropix Sampling", info="Dynamically samples based on entropy/varentropy values from the logits / attention scores.")
|
||||
with gr.Row():
|
||||
layout["inference_tts"]["inputs"]["layer-skip-exit-layer"] = gr.Slider(value=11, minimum=0, maximum=11, step=1, label="Layer Skip Exit Layer", info="Model layer to exit early from.")
|
||||
layout["inference_tts"]["inputs"]["layer-skip-exit-layer"] = gr.Slider(value=11, minimum=0, maximum=11, step=1, label="Layer Skip Exit Layer", info="Maximum model layer to exit early from.")
|
||||
|
||||
|
||||
layout["inference_tts"]["buttons"]["inference"].click(
|
||||
|
|
Loading…
Reference in New Issue
Block a user