From ded746e15725fc8b930d5f4360f97039d099d8d1 Mon Sep 17 00:00:00 2001
From: mrq
Date: Sat, 2 Nov 2024 11:49:05 -0500
Subject: [PATCH] very, very naive layerskip speculative sampling (it just
 checks if the current layer's state is good enough)

---
 vall_e/models/ar_nar.py     | 32 +++++++-----
 vall_e/models/arch/llama.py |  7 ++-
 vall_e/models/base.py       | 97 ++++++++++++++++++++++++++++---------
 vall_e/plot.py              |  9 +++-
 vall_e/samplers.py          |  1 +
 vall_e/webui.py             |  7 +--
 6 files changed, 110 insertions(+), 43 deletions(-)

diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index 6563b83..ffaf0be 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -14,10 +14,10 @@ from torch.nn.utils.rnn import pad_sequence
 
 import random
 import math
+import time
 
 from einops import rearrange
 from torch import Tensor
 from tqdm import trange
-from time import perf_counter
 
 import logging
@@ -66,6 +66,7 @@ class AR_NAR(Base):
 		sampling_dry_base=1.75,
 		sampling_dry_allowed_length=2,
 		sampling_entropix=False,
+		sampling_layer_skip: bool = False,
 		sampling_layer_skip_exit_layer: int = -1,
@@ -281,6 +282,11 @@ class AR_NAR(Base):
 			original_sampling_repetition_penalty_decay = sampling_repetition_penalty_decay
 		"""
 
+		sampling_layer_skip_variables = {} if sampling_layer_skip else None
+
+		if sampling_layer_skip:
+			sampling_layer_skip_variables["max_layer"] = sampling_layer_skip_exit_layer if sampling_layer_skip_exit_layer >= 0 else self.n_layers
+
 		for i, sequence in enumerate( sequence_list ):
 			# add to text for STT
 			if task_list[i] in text_task:
@@ -329,7 +335,7 @@ class AR_NAR(Base):
 				inputs=inputs,
 				state=state,
 
-				layer_skip_exit_layer=sampling_layer_skip_exit_layer,
+				layer_skip_variables=sampling_layer_skip_variables,
 
 				output_attentions=sampling_entropix,
 			)
@@ -360,15 +366,11 @@ class AR_NAR(Base):
 
 			r = sampled[0]
 
-			if sampled.entropy:
-				metrics.append( sampled.entropy )
-			"""
-			elif sampled.confidence:
-				metrics.append( sampled.confidence )
-			"""
-			elif False:
-				p = [ { "p": torch.nn.functional.softmax(logit[-1, :].cpu(), dim=0)[token.item()].item() } for logit, token in zip(logits, r) ]
-				metrics.append( p )
+			if cfg.experimental:
+				if sampled.entropy:
+					metrics.append( sampled.entropy )
+				elif sampled.scores:
+					metrics.append( [ { "p": p[0] } for p in sampled.scores ] )
 
 			if mirostat is not None:
 				mirostat = sampled.scores
@@ -402,7 +404,13 @@ class AR_NAR(Base):
 
 		if metrics:
 			from ..plot import plot_sample_metrics
-			plot_sample_metrics( metrics )
+			filename = "metrics"
+			if sampling_entropix:
+				filename += '[entropix]'
+			if sampling_layer_skip_exit_layer >= 0:
+				filename += f'[{sampling_layer_skip_exit_layer+1}]'
+
+			plot_sample_metrics( metrics, filename=f'{filename}.png' )
 
 		# pick the best scoring candidate
 		# desu this is always going to be candidate 0
diff --git a/vall_e/models/arch/llama.py b/vall_e/models/arch/llama.py
index 1550def..c5ad089 100644
--- a/vall_e/models/arch/llama.py
+++ b/vall_e/models/arch/llama.py
@@ -358,7 +358,8 @@ class LlamaModel_Adapted(LlamaModel):
 		output_hidden_states: Optional[bool] = None,
 		return_dict: Optional[bool] = None,
 		cache_position: Optional[torch.LongTensor] = None,
-		exit_layer: Optional[int] = -1,
+
+		layer_skip_lambda = None,
 	) -> Union[Tuple, BaseModelOutputWithPast]:
 		output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
 		output_hidden_states = (
@@ -451,7 +452,9 @@ class LlamaModel_Adapted(LlamaModel):
 			if output_attentions:
 				all_self_attns += (layer_outputs[1],)
 
-			if 0 <= exit_layer and exit_layer <= l:
+			# check if we should early-exit
+			if layer_skip_lambda and layer_skip_lambda( l, hidden_states ):
+				#_logger.info(f"Early exit at layer: {l}")
 				break
 
 		hidden_states = self.norm(hidden_states)
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index 489f2ba..59565df 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -38,8 +38,9 @@ from ..emb.qnt import encode_as_embedding
 # yuck, kind of needed
 from ..data import get_task_symmap
 
+# these seem more elegant than a dict
 Logits = namedtuple('Logits', ['logits', 'state', 'aux_loss', 'attentions', 'hidden_states'])
-Sampled = namedtuple('Sampled', ['out', 'scores', 'entropy']) # these seem more elegant than a dict
+Sampled = namedtuple('Sampled', ['out', 'scores', 'entropy'])
 LossStats = namedtuple('LossStats', ['loss', 'stats'])
 
 """
@@ -476,6 +477,7 @@ class Base(nn.Module):
 		self.unified_position_ids = unified_position_ids
 		self.interleave = interleave
 		self.layerskip = layerskip
+		self.special_tasks = [ "len", "stt" ]
 
 		self.text_emb = Embedding(n_text_tokens, d_model)
 		self.langs_emb = None
@@ -827,7 +829,7 @@ class Base(nn.Module):
 
 		state = None,
 
-		layer_skip_exit_layer = -1,
+		layer_skip_lambda = None,
 
 		output_attentions = False,
 		output_hidden_states = False,
@@ -846,7 +848,7 @@ class Base(nn.Module):
 				inputs_embeds=x,
 				past_key_values=state,
 				position_ids=position_ids,
-				use_cache=not self.training,
+				use_cache=False, # not self.training,
 				output_attentions=output_attentions,
 				output_hidden_states=output_hidden_states,
 				return_dict=True,
@@ -855,8 +857,8 @@ class Base(nn.Module):
 			if self.n_experts > 1 and self.training:
 				kwargs["output_router_logits"] = True
 
-			if self.layerskip and 0 <= layer_skip_exit_layer and layer_skip_exit_layer < self.n_layers:
-				kwargs["exit_layer"] = layer_skip_exit_layer
+			if self.layerskip and layer_skip_lambda is not None:
+				kwargs["layer_skip_lambda"] = layer_skip_lambda
 
 			output = self.model(**kwargs)
 			x = output["last_hidden_state"]
@@ -938,14 +940,6 @@ class Base(nn.Module):
 			# but skip the last state, as it already is normalized
 			hidden_states = [ x if i == self.n_layers - 1 else self.model.norm(output.hidden_states[i]) for i, state in enumerate( hidden_states ) ]
 
-		# output projection layer with masking
-		if self.classifier is not None:
-			x = self.classifier(x) * mask
-
-			if output.hidden_states:
-				for i, state in enumerate( hidden_states ):
-					hidden_states[i] = self.classifier(hidden_states[i]) * m
-
 		return Logits(x, state, aux_loss, attentions, hidden_states)
 
 	# takes a bunch of separate lists and parses them into an ordered array of tuples to guide input sequence creation
@@ -965,8 +959,6 @@ class Base(nn.Module):
 		device = text_list[0].device
 		batch_size = len(text_list)
 
-		special_tasks = ["stt", "len"]
-
 		inputs = [ [] for _ in range(batch_size) ]
 		for i in range(batch_size):
 			quant_level = quant_levels[i] if quant_levels is not None else 0
@@ -981,7 +973,7 @@ class Base(nn.Module):
 			# Base-line TTS task
 			# Sequence: <text><sep><rvq lvl><sep><prom><sep><resp>
 			# prom /may/ include <task> tokens inside to help guide things, per SpeechX
-			if f'<{task_type}>' in get_task_symmap() and task_type not in special_tasks:
+			if f'<{task_type}>' in get_task_symmap() and task_type not in self.special_tasks:
 				# insert the text prompt
 				if text_list is not None and text_list[i] is not None:
 					inputs[i].append( ( "text", text_list[i] ) )
@@ -1259,9 +1251,8 @@ class Base(nn.Module):
 		stats = dict(acc = dict())
 
 		device = logits[0].device
-		special_tasks = [ "len", "stt" ]
 		summed_embeddings_task = [ "stt" ]
-		classifier_quant_levels = quant_levels if self.classifier is not None else [ -1 if inputs[i][0][-1] in special_tasks else l for i, l in enumerate( quant_levels ) ]
+		classifier_quant_levels = quant_levels if self.classifier is not None else [ -1 if inputs[i][0][-1] in self.special_tasks else l for i, l in enumerate( quant_levels ) ]
 
 		# handles tasks where the prompt has task tokens injected in the middle
 		def prompt_input_to_token( input, quant_level ):
@@ -1443,11 +1434,46 @@ class Base(nn.Module):
 		quant_levels: int | list[int] | Tensor | None = None,
 		state: dict | list | None = None,
 
-		layer_skip_exit_layer: int = -1,
+
+		layer_skip_variables: dict | None = None,
 
 		output_attentions: bool = False,
 		output_hidden_states: bool = False,
 	):
+		# return early if the current layer's output is "good enough"
+		# a closure, since it needs to capture classifier_quant_levels and the mask (m)
+		def layer_skip_lambda( layer, hidden_states ):
+			kwargs = {
+				"logits_entropy": 0.1,
+				"logits_varentropy": 0.1,
+				"min_layer": self.n_layers // 2,
+				"max_layer": self.n_layers,
+			}
+
+			kwargs.update( layer_skip_variables )
+
+			# don't bother on early layers
+			if layer < kwargs["min_layer"]:
+				return False
+			# bail if we want to force an exit at a given layer
+			if kwargs["max_layer"] < layer:
+				return True
+
+			# hidden states aren't normalized
+			x = self.model.norm( hidden_states )
+
+			# output projection layer with masking
+			if self.classifier is not None:
+				x = self.classifier(x) * m
+			elif self.classifiers is not None:
+				x = self.classifiers(x, levels = classifier_quant_levels) * m
+
+			# calculate metrics on the projected logits
+			metrics = calculate_entropix_metrics( x )
+
+			# exit early if "good enough"
+			return metrics["logits_entropy"] < kwargs["logits_entropy"] and metrics["logits_varentropy"] < kwargs["logits_varentropy"]
+
 		x_list = self.inputs_to_embeddings( inputs, quant_levels )
 		x, m = list_to_tensor(x_list)
@@ -1459,7 +1485,8 @@ class Base(nn.Module):
 		if quant_levels is None:
 			quant_levels = [ 0 for _ in range(batch_size) ]
 
-		if self.layerskip:
+		# we only need hidden states if we're training with layerskip
+		if self.layerskip and training:
 			output_hidden_states = True
 
 		# pad our input and mask, but retain the original length by doing it after
@@ -1478,6 +1505,8 @@ class Base(nn.Module):
 		# needs to be done here as we still have our raw inputs
 		position_ids = self.inputs_to_position_ids( inputs, mask=m.squeeze(-1).int() ) if not self.unified_position_ids else None
+
+		classifier_quant_levels = [ -1 if inputs[i][0][-1] in self.special_tasks else l for i, l in enumerate( quant_levels ) ]
 
 		output = self._forward(
 			inputs=x,
 			state=state,
@@ -1484,17 +1513,23 @@ class Base(nn.Module):
 			position_ids=position_ids,
 			output_attentions = output_attentions,
 			output_hidden_states = output_hidden_states,
-			layer_skip_exit_layer = layer_skip_exit_layer,
+			layer_skip_lambda = layer_skip_lambda if self.layerskip and layer_skip_variables else None,
 		)
 
 		logits = output.logits
 		hidden_states = output.hidden_states
 
+		# output projection layer with masking
+		if self.classifier is not None:
+			logits = self.classifier(logits) * m
+
+			if output.hidden_states:
+				for i, _ in enumerate( hidden_states ):
+					hidden_states[i] = self.classifier(hidden_states[i]) * m
+
 		# to-do: piece-wise classification, now that there's a head for text
 		# although again, one single monolithic head would be preferable instead......
 		if self.classifiers is not None:
-			special_tasks = [ "len", "stt" ]
-			classifier_quant_levels = [ -1 if inputs[i][0][-1] in special_tasks else l for i, l in enumerate( quant_levels ) ]
 			logits = self.classifiers(logits, levels = classifier_quant_levels) * m
 
 			if hidden_states is not None:
@@ -1508,7 +1542,6 @@ class Base(nn.Module):
 
 		if hidden_states is not None:
 			for i, state in enumerate( hidden_states ):
-				# remove padding
 				hidden_states[i] = [ hi[:li] for hi, li in zip(hidden_states[i], map(len, x_list)) ]
 
 		# compute loss if the target is given
@@ -1573,6 +1606,8 @@ class Base(nn.Module):
 		# other
 		attentions=None,
 	):
+		batch_size = len( logits )
+
 		if min_temperature < 0:
 			min_temperature = temperature
@@ -1598,6 +1633,14 @@ class Base(nn.Module):
 
 		if res:
 			return Sampled([ r[0] for r in res ], scores, [ r[1] for r in res ])
+		"""
+		elif quant_levels is None:
+			seq_lens = [ logit.shape[0] for logit in logits ]
+			entropy = [ calculate_entropix_metrics(
+				logit[:seq_lens[batch], :], # ( seq_len, vocab )
+				#attentions[batch, :, :, :seq_lens[batch], :seq_lens[batch]], # ( layer, heads, seq_len, seq_len )
+			) for batch, logit in enumerate(logits) ]
+		"""
 
 		# (NAR) return the entire generated response
 		# Parallel decoding relies on the last N tokens in the logits, because each token predicts the next RVQ layer in the same place (forgetfully obviously)
@@ -1666,4 +1709,10 @@ class Base(nn.Module):
 		else:
 			res = [ Categorical(logits=logit).sample() for logit in logits ]
 
+		# calculate token probabilities
+		scores = [
+			[ F.softmax(logit[-1, :], dim=0)[token].item() for token in tokens ]
+			for logit, tokens in zip(logits, res)
+		]
+
 		return Sampled(res, scores, entropy)
\ No newline at end of file
diff --git a/vall_e/plot.py b/vall_e/plot.py
index d5111ae..ec5dfac 100644
--- a/vall_e/plot.py
+++ b/vall_e/plot.py
@@ -2,6 +2,7 @@
 import argparse
 import json
+import time
 import re
 
 from pathlib import Path
@@ -93,7 +94,7 @@ def plot(paths, args):
 		#bbox_to_anchor=(1.04, 0.5),
 	)
 
-def plot_sample_metrics( metrics ):
+def plot_sample_metrics( metrics, filename=None ):
 	"""
 	fig = plt.figure()
 	fig.set_figwidth( 16 * len(metrics) // cfg.dataset.frames_per_second )
@@ -111,7 +112,11 @@ def plot_sample_metrics( metrics, filename=None ):
 		#bbox_to_anchor=(1.04, 0.5),
 	)
 
-	out_path = cfg.rel_path / "metrics.png"
+	if not filename:
+		filename = f'{time.time()}.png'
+
+	out_path = cfg.rel_path / "metrics" / filename
+	out_path.parent.mkdir(parents=True, exist_ok=True)
 	plt.savefig(out_path, bbox_inches="tight")
 
 if __name__ == "__main__":
diff --git a/vall_e/samplers.py b/vall_e/samplers.py
index fc55291..c149ed4 100644
--- a/vall_e/samplers.py
+++ b/vall_e/samplers.py
@@ -426,6 +426,7 @@ def sample_entropix(
 	top_p=1.0,
 	min_p=0.0,
 	cfg=EntropixSamplerConfig(),
+	metrics_only=False,
 ):
 	"""
 	temperature = cfg.temp
diff --git a/vall_e/webui.py b/vall_e/webui.py
index b8de13c..15816ae 100644
--- a/vall_e/webui.py
+++ b/vall_e/webui.py
@@ -193,7 +193,7 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
 	parser.add_argument("--dry-allowed-length", type=int, default=kwargs["dry-allowed-length"])
 	parser.add_argument("--entropix-sampling", action="store_true")
 	parser.add_argument("--layer-skip", action="store_true")
-	parser.add_argument("--layer-skip-exit-layer", type=int, default=kwargs["layer-skip-exit-layer"])
+	parser.add_argument("--layer-skip-exit-layer", type=int, default=kwargs["layer-skip-exit-layer"] if cfg.experimental else -1)
 	args, unknown = parser.parse_known_args()
 
 	tmp = tempfile.NamedTemporaryFile(suffix='.wav')
@@ -384,7 +384,7 @@ with ui:
 					layout["inference_tts"]["inputs"]["ar-temp"] = gr.Slider(value=0.5, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR. (0 to greedy* sample)")
 					layout["inference_tts"]["inputs"]["nar-temp"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (NAR)", info="Modifies the randomness from the samples in the NAR. (0 to greedy sample)")
 				with gr.Row():
-					layout["inference_tts"]["inputs"]["dynamic-sampling"] = gr.Checkbox(label="Dynamic Temperature", info="Dynamically adjusts the temperature based on the highest confident predicted token per sampling step.")
+					layout["inference_tts"]["inputs"]["layer-skip"] = gr.Checkbox(label="Layer Skip", info="Performs self-speculative early-exit 'sampling'.")
 					layout["inference_tts"]["inputs"]["language"] = gr.Dropdown(choices=get_languages(), label="Language", value="en")
 			with gr.Tab("Sampler Settings"):
 				with gr.Row():
@@ -411,9 +411,10 @@ with ui:
 				with gr.Row():
 					layout["inference_tts"]["inputs"]["prefix-silence"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.0, step=0.05, label="Silence Prefix Duration", info="Amount of silence to prefix to the output response before beginning inference.")
 				with gr.Row():
+					layout["inference_tts"]["inputs"]["dynamic-sampling"] = gr.Checkbox(label="Dynamic Temperature", info="Dynamically adjusts the temperature based on the highest confident predicted token per sampling step.")
 					layout["inference_tts"]["inputs"]["entropix-sampling"] = gr.Checkbox(label="Entropix Sampling", info="Dynamically samples based on entropy/varentropy values from the logits / attention scores.")
 				with gr.Row():
-					layout["inference_tts"]["inputs"]["layer-skip-exit-layer"] = gr.Slider(value=11, minimum=0, maximum=11, step=1, label="Layer Skip Exit Layer", info="Model layer to exit early from.")
+					layout["inference_tts"]["inputs"]["layer-skip-exit-layer"] = gr.Slider(value=11, minimum=0, maximum=11, step=1, label="Layer Skip Exit Layer", info="Maximum model layer to exit early from.")
 
 		layout["inference_tts"]["buttons"]["inference"].click(
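
Note (editor's addition, not part of the patch): the early-exit rule in
layer_skip_lambda reduces to thresholding the entropy and varentropy of the
current layer's normed-and-projected logits. Below is a minimal standalone
sketch of just that decision, assuming logits of shape ( seq_len, vocab ) and
approximating calculate_entropix_metrics with a plain softmax
entropy/varentropy; the 0.1 thresholds mirror the patch's defaults, while the
names and toy shapes are purely illustrative:

	# illustrative sketch only; not the repo's exact metric computation
	import torch
	import torch.nn.functional as F

	ENTROPY_T    = 0.1  # the patch's default "logits_entropy" threshold
	VARENTROPY_T = 0.1  # the patch's default "logits_varentropy" threshold

	def should_exit( logits: torch.Tensor ) -> bool:
		# logits: ( seq_len, vocab ), the current layer's hidden state after norm + projection
		log_probs = F.log_softmax( logits[-1, :], dim=-1 ) # last position, i.e. the token being sampled
		probs = log_probs.exp()
		entropy = -( probs * log_probs ).sum() # H = -sum p * log p
		varentropy = ( probs * ( log_probs + entropy ) ** 2 ).sum() # Var[ -log p ]
		return entropy.item() < ENTROPY_T and varentropy.item() < VARENTROPY_T

	# a peaked distribution is "good enough" to exit on, a flat one is not
	peaked = torch.zeros( 4, 1024 )
	peaked[-1, 0] = 20.0
	flat = torch.zeros( 4, 1024 )
	assert should_exit( peaked ) and not should_exit( flat )

In the patch itself the thresholds (plus "min_layer"/"max_layer") come from
sampling_layer_skip_variables: the check never fires below min_layer and
always fires above max_layer, so sampling_layer_skip_exit_layer acts as a hard
cap on depth rather than a fixed exit point.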