decoupled llama backend to avoid any funny changes from transformers, removed other backends since I don't think I'll ever bother using them

This commit is contained in:
mrq 2025-02-27 19:00:37 -06:00
parent ceecac6ffe
commit eff180248c
11 changed files with 304 additions and 1792 deletions

View File

@@ -51,6 +51,7 @@ The reference model (`ar+nar-llama-8`/`ar+nar-len-llama-8`):
* [ ] train a serviceable model for 44KHz audio (instead of 24KHz)
* [ ] well-integrated training through the Web UI (without the kludge from ai-voice-cloning)
* [x] clean up the README, and document, document, document.
* [ ] clean up the documentation again, as most of it feels like schizo-rambling...
* [x] extend to multiple languages ([VALL-E X](https://arxiv.org/abs/2303.03926)).
- reference model is trained against English, Japanese, French, German, Korean, and Chinese (Mandarin?).
- [x] improve multi-lingual support
@@ -78,6 +79,7 @@ The reference model (`ar+nar-llama-8`/`ar+nar-len-llama-8`):
* these features are predicated on the model being trained for it
* [ ] smarter/clever inferencing, such as:
* [x] inference *all* codebooks in one pass, rather than each level being its own discrete pass.
* `cfg.model.version >= 7` models will rely on this
* these features are predicated on the model being trained for it
* [x] "rolling" context, where the last generated sentence is the prefix for the next sentence.
* [ ] for the AR, stop inferencing sequences in the batch that have already hit their stop token
@@ -86,6 +88,9 @@ The reference model (`ar+nar-llama-8`/`ar+nar-len-llama-8`):
* [x] SIM-O requires passing the raw waveform through a speaker-similarity model
* [x] valle.cpp through llama.cpp + encodec.cpp
* extend to decode with vocos.cpp, instead, for a quality improvement
* [ ] 44KHz audio, through either DAC or `nvidia/audio-codec-44khz`
* the former has quality issues in the higher RVQ levels, though these may be resolved with the experimental implementation
* the latter needs testing, as being an FSQ codec means it requires extra care
## "Postmortem"

View File

@@ -9,10 +9,11 @@ The beauty of a transformer, I feel, is that you can easily define any task at i
The inputs are automatically sequenced in a way that a given task requires, and the outputs are handled as per the class that extends the base model.
While the original paper called for a separate AR model and a NAR model, and by treating the AR and the NAR as unique tasks, you can actually train a unified model (`AR+NAR`) for effectively free, as the internal states of the two should overlap quite a lot.
While the original paper called for a separate AR model and a NAR model, by treating the AR and the NAR as unique tasks, you can actually train a unified model (`AR+NAR`) for effectively free, as the internal states of the two should overlap quite a lot.
* Additionally, you can even train a `NAR-len` model on top of an existing model.
Later papers for discrete TTS solutions address the multiple-codebook problem by introducing exotic interleaving patterns. For all intents and purposes, these aren't necessary, as the current approach of prioritizing the first codebook (RVQ level 0) works well enough. The remaining RVQ levels can be easily deduced from the prior level in parallel.
* Exotic solutions aren't necessary at all, as the summed embeddings can be good enough to represent the original waveform. Output codes can be inferenced in parallel with a wider head, removing the need to train separate levels (see the sketch below).
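As a concrete sketch of the above (names and sizes like `N_LEVELS`, `N_CODES`, and `D_MODEL` are illustrative assumptions, not the repo's actual implementation), summing one embedding per RVQ level keeps the sequence length the same as a single-codebook setup, and a set of per-level heads emits every level's logits in one pass:

```python
import torch
import torch.nn as nn

N_LEVELS = 8     # assumed number of RVQ levels
N_CODES  = 1024  # assumed codebook size per level
D_MODEL  = 1024  # assumed model width

class SummedCodebookEmbedding(nn.Module):
	def __init__(self):
		super().__init__()
		self.embs = nn.ModuleList([nn.Embedding(N_CODES, D_MODEL) for _ in range(N_LEVELS)])

	def forward(self, codes):  # codes: [batch, seq_len, N_LEVELS]
		# sum the per-level embeddings into one vector per timestep
		return sum(emb(codes[..., i]) for i, emb in enumerate(self.embs))

class ParallelHeads(nn.Module):
	def __init__(self):
		super().__init__()
		self.heads = nn.ModuleList([nn.Linear(D_MODEL, N_CODES) for _ in range(N_LEVELS)])

	def forward(self, hidden):  # hidden: [batch, seq_len, D_MODEL]
		# one pass yields logits for every RVQ level at once: [batch, seq_len, N_LEVELS, N_CODES]
		return torch.stack([head(hidden) for head in self.heads], dim=-2)
```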
## The AR (Autoregressive) Model
@@ -41,8 +42,6 @@ Compared to non-autoregressive decoding, I personally feel that autoregressive e
Technically, with `cfg.model.version >= 7`, a model can be purely AR, as that version of the model encodes and decodes all codebooks of audio in a single pass.
Inferencing code is not available at the moment for this modality, but will be available in the future.
## The NAR (Non-autoregressive) Model
The NAR is responsible for generating the remaining RVQ levels of the audio codes for a given output. References to the "outputs from the NAR" refer to the underlying "levels" for a given waveform, as each further level contributes less significantly to the final waveform than the previous one.
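For intuition on why later levels matter less, here is an illustrative residual-vector-quantization decode (shapes and names are assumptions, not any codec's real API): each level only adds a finer correction on top of the previous ones, so an error at level 7 perturbs the waveform far less than an error at level 0.

```python
import torch

def rvq_decode(codes, codebooks):
	# codes: [seq_len, n_levels] integer indices; codebooks: list of [n_codes, dim] tensors (assumed shapes)
	out = torch.zeros(codes.shape[0], codebooks[0].shape[1])
	for level, book in enumerate(codebooks):
		out = out + book[codes[:, level]]  # each level contributes a smaller residual than the last
	return out
```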

View File

@@ -39,6 +39,11 @@ With the other "hyperparameters" such as ratios for RVQ levels, tasks, etc:
* it might later be necessary to prefer a more balanced distribution (such as `[0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7]`) to get rid of any confidence issues in RVQ levels 1+, but I felt naively doing this harms RVQ level 0.
* `prompt_similar_p` can be pretty much anything > `0.5`. I've stuck with either `0.75` or `0.825` to prioritize adhering more closely to the prompt, while still having random prompts used to help the model internally "model" what a speaker should sound like. In theory. (A rough sketch of these two knobs follows this list.)
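A rough sketch of how the two knobs above behave; the names (`rvq_ratio_list`, `sample_training_knobs`) and the exact mechanism are illustrative assumptions rather than the actual config or dataloader code:

```python
import random

rvq_ratio_list   = [0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7]  # RVQ level 0 gets half the draws
prompt_similar_p = 0.825

def sample_training_knobs():
	rvq_level   = random.choice(rvq_ratio_list)        # which RVQ level this sample trains against
	use_similar = random.random() < prompt_similar_p   # whether to pick a prompt similar to the target
	return rvq_level, use_similar
```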
The optimizer used *mostly* doesn't matter, as AdamW seems to get moving faster, while Prodigyopt keeps things stable in the long run (a rough setup sketch follows the list below).
* `APOLLO` needs more testing, but seemed adequate in cursory tests
* `Muon` requires much more testing, but absolutely cannot be used for predicting tokens in place (NAR demasking), and requires `cfg.model.experimental.predict_causally=True`
* I honestly don't think it gives good enough results from cursory tests for this application
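A minimal sketch of wiring up either optimizer, assuming the `prodigyopt` package provides Prodigy; this is not the project's actual trainer code and the hyperparameters are placeholders:

```python
import torch

def build_optimizer(model, name: str = "adamw"):
	params = model.parameters()
	if name == "prodigy":
		from prodigyopt import Prodigy  # assumed import path
		# Prodigy is typically run with lr=1.0 and left to estimate its own step size
		return Prodigy(params, lr=1.0)
	return torch.optim.AdamW(params, lr=1.0e-4, weight_decay=0.01)
```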
## Try Me
To quickly test if a configuration works, you can run `python -m vall_e.models.ar_nar --yaml="./data/config.yaml"`; a small trainer will overfit a provided utterance.

View File

@@ -1,6 +1,15 @@
AVAILABLE_ARCHES = []
ERROR_ARCHES = {}
try:
from .llama import Config as LlamaConfig, Model as LlamaModel, Attention as LlamaAttention, AVAILABLE_ATTENTIONS
AVAILABLE_ARCHES.append("llama")
except Exception as e:
ERROR_ARCHES["llama"] = e
AVAILABLE_ATTENTIONS = []
pass
"""
try:
from .transformer import SinusoidalEmbedding, Block as TransformerBlock
AVAILABLE_ARCHES.append("transformer")
@@ -15,7 +24,6 @@ except Exception as e:
ERROR_ARCHES["retnet"] = e
pass
"""
try:
from .retnet_syncdoth.retnet_ts import RetNetDecoder as RetNetDecoder_TS, RetNetConfig as RetNetConfig_TS
AVAILABLE_ARCHES.append("retnet-ts")
@@ -29,15 +37,6 @@ try:
except Exception as e:
ERROR_ARCHES["retnet-hf"] = e
pass
"""
try:
from .llama import LlamaModel, LlamaModel_Adapted, LlamaConfig, AVAILABLE_ATTENTIONS, LlamaAttention, LlamaAttention_Adapted, LlamaDecoderLayer, LlamaDecoderLayer_Adapted, LlamaForCausalLM
AVAILABLE_ARCHES.append("llama")
except Exception as e:
ERROR_ARCHES["llama"] = e
AVAILABLE_ATTENTIONS = []
pass
try:
from .bitnet import BitNetTransformer
@@ -59,7 +58,7 @@ try:
except Exception as e:
ERROR_ARCHES["mamba"] = e
ERROR_ARCHES["mamba2"] = e
"""
"""
try:
from .mamba import MambaMixelModel, MambaLMHeadModel, MambaConfig

View File

@@ -1,53 +0,0 @@
# https://github.com/kyegomez/BitNet
from torch import Tensor, nn
from torch.utils.checkpoint import checkpoint
from bitnet.bit_transformer import Transformer as BitNetTransformerBlock, RMSNorm as BitNetRMSNorm
# re-enable logging because zetascale fucking sucks
import logging
logging.getLogger().setLevel(logging.DEBUG)
# override for wrapping checkpointing
def BitNetTransformerBlock_forward(self, x: Tensor, *args, **kwargs) -> Tensor:
skip = x
for attn, ffn in zip(self.layers, self.ffn_layers):
if x.requires_grad and self.gradient_checkpointing:
x, _ = checkpoint(attn, x, x, x, is_causal=True, *args, **kwargs, use_reentrant=False)
else:
x, _ = attn(x, x, x, is_causal=True, *args, **kwargs)
x = x + skip
x = ffn(x) + x
return x
BitNetTransformerBlock.forward = BitNetTransformerBlock_forward
# override because bitnet's BitNetTransformer includes an embedding input / classifier output layers inside of it, which isn't favorable
class BitNetTransformer(nn.Module):
def __init__(
self,
dim: int,
depth: int,
num_tokens: int,
heads=8,
ff_mult=4,
gradient_checkpointing = True
):
super().__init__()
self.transformer = BitNetTransformerBlock( dim=dim, depth=depth, heads=heads, ff_mult=ff_mult )
self.norm = BitNetRMSNorm(dim)
self.transformer.gradient_checkpointing = gradient_checkpointing
def forward(self, x):
x = self.transformer(x)
return self.norm( x )
"""
from bitnet import BitNetTransformer
def NoEmbedding_BitNetTransformer_Forward(self, x):
x = self.transformer(x)
return self.to_logits[0](x)
BitNetTransformer.forward = NoEmbedding_BitNetTransformer_Forward
"""

View File

@@ -1,46 +1,157 @@
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
import math
import torch
import logging
import random
from typing import Literal, overload, Optional, Tuple, Union, List
from torch import Tensor, nn
from transformers.cache_utils import Cache
from transformers import LlamaModel, LlamaConfig, LlamaForCausalLM
# lazy
from transformers.models.llama.configuration_llama import LlamaConfig as Config
from transformers.models.llama.modeling_llama import LlamaPreTrainedModel
from transformers.modeling_utils import PreTrainedModel
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.cache_utils import Cache, DynamicCache, StaticCache
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaRMSNorm, LlamaRotaryEmbedding, apply_rotary_pos_emb, repeat_kv
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
from transformers.activations import ACT2FN
from .attention import *
LN_2 = 0.69314718056
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
class LlamaAttention_Adapted(LlamaAttention):
def __init__(self, *args, **kwargs):
self.mode = kwargs.pop("mode", "sdpa")
if n_rep == 1:
return hidden_states
if self.mode == "math":
self.mode = torch.nn.attention.SDPBackend.MATH
elif self.mode == "mem_efficient":
self.mode = torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION
elif self.mode == "flash_(sdpa)":
self.mode = torch.nn.attention.SDPBackend.FLASH_ATTENTION
elif self.mode == "cudnn":
self.mode = torch.nn.attention.SDPBackend.CUDNN_ATTENTION
elif self.mode == "sdpa":
self.mode = torch.nn.attention.SDPBackend.MATH
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
super().__init__(*args, **kwargs)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
if not hasattr(self, "num_heads"):
self.num_heads = self.config.num_attention_heads
if not hasattr(self, "num_key_value_heads"):
self.num_key_value_heads = self.config.num_key_value_heads
def rotate_half(x):
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
class RMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
def extra_repr(self):
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
class RotaryEmbedding(nn.Module):
def __init__(self, config, device=None):
super().__init__()
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
else:
self.rope_type = "default"
self.max_seq_len_cached = config.max_position_embeddings
self.original_max_seq_len = config.max_position_embeddings
self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq
def _dynamic_frequency_update(self, position_ids, device):
seq_len = torch.max(position_ids) + 1
if seq_len > self.max_seq_len_cached: # growth
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.max_seq_len_cached = seq_len
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len:
self.original_inv_freq = self.original_inv_freq.to(device)
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
self.max_seq_len_cached = self.original_max_seq_len
@torch.no_grad()
def forward(self, x, position_ids):
if "dynamic" in self.rope_type:
self._dynamic_frequency_update(position_ids, device=x.device)
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
cos = cos * self.attention_scaling
sin = sin * self.attention_scaling
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
class Attention(nn.Module):
def __init__(self, config, layer_idx, mode = "default"):
super().__init__()
self.config = config
self.layer_idx = layer_idx
self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
self.scaling = self.head_dim**-0.5
self.attention_dropout = config.attention_dropout
self.num_heads = config.num_attention_heads
self.num_key_value_heads = config.num_key_value_heads
self.attn_mode = mode
if self.attn_mode == "math":
self.attn_mode = torch.nn.attention.SDPBackend.MATH
elif self.attn_mode == "mem_efficient":
self.attn_mode = torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION
elif self.attn_mode == "flash_(sdpa)":
self.attn_mode = torch.nn.attention.SDPBackend.FLASH_ATTENTION
elif self.attn_mode == "cudnn":
self.attn_mode = torch.nn.attention.SDPBackend.CUDNN_ATTENTION
elif self.attn_mode == "sdpa":
self.attn_mode = torch.nn.attention.SDPBackend.MATH
self.q_proj = nn.Linear(
config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
)
self.k_proj = nn.Linear(
config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
)
self.v_proj = nn.Linear(
config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
)
self.o_proj = nn.Linear(
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
)
# extracts inputs from a batch based on requested causality
def split_forward(
@@ -115,7 +226,7 @@ class LlamaAttention_Adapted(LlamaAttention):
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
mode = "default" if output_attentions else self.mode
mode = "default" if output_attentions else self.attn_mode
non_split_attention = [
"default",
torch.nn.attention.SDPBackend.MATH,
@@ -207,34 +318,10 @@ class LlamaAttention_Adapted(LlamaAttention):
attn_scores = None
if mode in ["xformers", "flash_attn"]:
# TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
# to be able to avoid many of these transpose/reshape/view.
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
"""
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
# therefore the input hidden states gets silently casted in float32. Hence, we need
# cast them back in the correct dtype just to be sure everything works as expected.
# This might slowdown training & inference so it is recommended to not cast the LayerNorms
# in fp32. (LlamaRMSNorm handles it correctly)
input_dtype = query_states.dtype
if input_dtype == torch.float32:
if torch.is_autocast_enabled():
target_dtype = torch.get_autocast_gpu_dtype()
# Handle the case where the model is quantized
elif hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = self.config._pre_quantization_dtype
else:
target_dtype = self.q_proj.weight.dtype
query_states = query_states.to(target_dtype)
key_states = key_states.to(target_dtype)
value_states = value_states.to(target_dtype)
"""
if mode == "flash_attn":
attn_output = flash_attn_func(
query_states,
@@ -311,7 +398,7 @@ class LlamaAttention_Adapted(LlamaAttention):
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
# is_causal = True if x_mask is None and q_len > 1 else False
is_causal = True if x_mask is None and q_len > 1 else False
with torch.nn.attention.sdpa_kernel(self.mode):
with torch.nn.attention.sdpa_kernel(self.attn_mode):
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
@@ -332,28 +419,33 @@ class LlamaAttention_Adapted(LlamaAttention):
return attn_output, attn_scores, past_key_value
class LlamaDecoderLayer_Adapted(LlamaDecoderLayer):
# apply timestep embedding with attention norm
# I don't have a concrete idea on how helpful this is, as:
# * F5-TTS's UNetT implementation doesn't do this
# * F5-TTS's DiT does this, but only for pre-attention normalization
# * MaskGCT does this for both
# * Muse doesn't do this, but instead appends the timestep embedding
def weigh_by_timestep(
self,
hidden_states,
timesteps,
):
if timesteps is None:
return hidden_states
class MLP(nn.Module):
def __init__(self, config):
super().__init__()
for i, timestep in enumerate( timesteps ):
# invalid
if not isinstance( timestep, torch.Tensor ):
continue
hidden_states[i] *= timestep
return hidden_states
self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
self.act_fn = ACT2FN[config.hidden_act]
def forward(self, x):
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
return down_proj
class DecoderLayer(nn.Module):
def __init__(self, config, layer_idx):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = Attention(config=config, layer_idx=layer_idx)
self.mlp = MLP(config)
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
@@ -366,35 +458,11 @@ class LlamaDecoderLayer_Adapted(LlamaDecoderLayer):
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
timesteps: Optional[list] = None,
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`, *optional*):
attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
query_sequence_length, key_sequence_length)` if default attention is used.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
(see `past_key_values`).
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence
position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
with `head_dim` being the embedding dimension of each attention head.
kwargs (`dict`, *optional*):
Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
into the model
"""
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states = self.weigh_by_timestep( hidden_states, timesteps )
# ugh
if isinstance( is_causal, list ) and len(is_causal) == 1:
@@ -418,8 +486,6 @@ class LlamaDecoderLayer_Adapted(LlamaDecoderLayer):
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.weigh_by_timestep( hidden_states, timesteps )
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
@@ -433,64 +499,24 @@ class LlamaDecoderLayer_Adapted(LlamaDecoderLayer):
return outputs
class LlamaModel_Adapted(LlamaModel):
def __init__(self, config, *args, **kwargs):
self.layer_dropout_p = kwargs.pop("layer_dropout_p", 0)
self.early_exit_scale = kwargs.pop("early_exit_scale", 0.1)
self.early_exit_r = kwargs.pop("early_exit_r", 2)
#super().__init__(*args, **kwargs)
super(LlamaModel, self).__init__(config)
class Model(LlamaPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.layers_n = config.num_hidden_layers
# self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
self.layers = nn.ModuleList(
[LlamaDecoderLayer_Adapted(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
[DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.rotary_emb = LlamaRotaryEmbedding(config=config)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.rotary_emb = RotaryEmbedding(config=config)
self.gradient_checkpointing = False
# Initialize weights and apply final processing
self.post_init()
def dropoff_layer( self, l ):
if not self.training or self.layer_dropout_p <= 0:
return False
# this could probably a LUT but I'm not fiending for aggressive mal-optimizations
D = math.exp((l * LN_2) / (self.layers_n - 1)) - 1
P = D * self.layer_dropout_p
return random.random() < P
def cirriculum( self, l, t=None ):
# no timestep data passed, just treat all layers as enabled
# there doesn't seem /too/ bad of a performance hit, but the paper mentions it affecting accuracy of the last layer if all layers had early exit
if t is None:
return 1
# YUCK
# this guarantees at least R layers are active at all intervals, which is important because this gives a division by zero otherwise
for i in range(self.early_exit_r):
if l == ((t % self.layers_n) + i * (self.layers_n // self.early_exit_r)) % self.layers_n:
return 1
return 0
def early_exit_loss( self, losses, t=None ):
return sum([ self.normalized_per_layer_loss_scale( l, t ) * losses[l] for l in range(0, self.layers_n) ])
def normalized_per_layer_loss_scale( self, l, t=None ):
return (self.cirriculum(l, t) * self.early_exit_factor( l )) / sum([ self.cirriculum(i, t) * self.early_exit_factor( i ) for i in range(0, self.layers_n) ])
def early_exit_factor( self, l ):
if 0 <= l and l < self.layers_n:
return self.early_exit_scale * sum([ i for i in range(0, l) ])
return self.layers_n - 1 + self.early_exit_scale * sum([ i for i in range(0, self.layers_n - 1) ])
# shamelessly borrowed from https://github.com/open-mmlab/Amphion/blob/main/models/tts/maskgct/llama_nar.py#L256 until I replace it with my own noncausal-mask maker
def _update_noncausal_mask(
self,
@@ -514,6 +540,41 @@ class LlamaModel_Adapted(LlamaModel):
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill( inverted_mask.to(dtype=torch.bool), torch.finfo(inputs_embeds.dtype).min )
@staticmethod
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
):
if attention_mask is not None and attention_mask.dim() == 4:
causal_mask = attention_mask
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
causal_mask.device
)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
# gut out the things that just shoves responsibility on SDPA's is_causal generating a mask because this causes problems
def _update_causal_mask(
self,
@@ -523,31 +584,9 @@ class LlamaModel_Adapted(LlamaModel):
past_key_values: Cache,
output_attentions: bool,
):
"""
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
"""
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
using_static_cache = isinstance(past_key_values, StaticCache)
"""
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
# to infer the attention mask.
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
if AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask,
inputs_embeds=input_tensor,
past_key_values_length=past_seen_tokens,
is_training=self.training,
):
return None
"""
dtype, device = input_tensor.dtype, input_tensor.device
sequence_length = input_tensor.shape[1]
if using_static_cache:
@@ -576,9 +615,6 @@ class LlamaModel_Adapted(LlamaModel):
and attention_mask.device.type == "cuda"
and not output_attentions
):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
min_dtype = torch.finfo(dtype).min
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
@@ -596,10 +632,7 @@ class LlamaModel_Adapted(LlamaModel):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
layer_skip_lambda = None,
timesteps: Optional[list] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@@ -617,11 +650,6 @@ class LlamaModel_Adapted(LlamaModel):
)
use_cache = False
"""
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
"""
# kept for BC (non `Cache` `past_key_values` inputs)
return_legacy_cache = False
if use_cache and not isinstance(past_key_values, Cache):
@@ -646,17 +674,6 @@ class LlamaModel_Adapted(LlamaModel):
# because we can attend to both a causal and a non-causal sequence, generate both masks then pick among which to use per batch
if is_causal is not None:
"""
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=inputs_embeds.shape[1],
target_length=attention_mask.shape[-1] if attention_mask is not None else inputs_embeds.shape[1],
dtype=inputs_embeds.dtype,
device=inputs_embeds.device,
cache_position=cache_position,
batch_size=inputs_embeds.shape[0],
)
"""
causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions)
noncausal_mask = self._update_noncausal_mask(attention_mask, inputs_embeds, past_key_values)
@@ -690,7 +707,6 @@ class LlamaModel_Adapted(LlamaModel):
use_cache,
cache_position,
position_embeddings,
timesteps,
)
else:
layer_outputs = decoder_layer(
@@ -703,22 +719,16 @@ class LlamaModel_Adapted(LlamaModel):
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
timesteps=timesteps,
)
if not self.dropoff_layer( l ):
hidden_states = layer_outputs[0]
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
if use_cache:
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
if output_attentions:
all_self_attns += (layer_outputs[1],)
# check if we should early-exit
if layer_skip_lambda and layer_skip_lambda( l, hidden_states ):
#_logger.info(f"Early exit at layer: {l}")
break
hidden_states = self.norm(hidden_states)

View File

@@ -1,51 +0,0 @@
from transformers.models.mamba.configuration_mamba import MambaConfig
from transformers.models.mamba.modeling_mamba import MambaModel
"""
from transformers.models.mamba2.modeling_mamba2 import Mamba2Model
from transformers.models.mamba2.configuration_mamba2 import Mamba2Config
"""
"""
from mamba2_torch.modeling.configuration_mamba2 import Mamba2Config
from mamba2_torch.modeling.modeling_mamba2 import Mamba2Model
"""
from fla.models.mamba2.configuration_mamba2 import Mamba2Config
from fla.models.mamba2.modeling_mamba2 import Mamba2Model
"""
# https://github.com/state-spaces/mamba
from torch.utils.checkpoint import checkpoint
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel, MambaConfig, MixerModel as MambaMixelModel, layer_norm_fn as MambaLayerNormFn, RMSNorm as MambaRMSNorm
def MambaMixelModel_forward(self, input_ids=None, hidden_states=None, inference_params=None, **mixer_kwargs):
if hidden_states is None:
hidden_states = self.embedding(input_ids)
residual = None
for layer in self.layers:
if self.gradient_checkpointing and hidden_states.requires_grad:
hidden_states, residual = checkpoint( layer, hidden_states, residual, inference_params=inference_params, **mixer_kwargs, use_reentrant=False )
else:
hidden_states, residual = layer( hidden_states, residual, inference_params=inference_params, **mixer_kwargs )
if not self.fused_add_norm:
residual = (hidden_states + residual) if residual is not None else hidden_states
hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
else:
# Set prenorm=False here since we don't need the residual
hidden_states = MambaLayerNormFn(
hidden_states,
self.norm_f.weight,
self.norm_f.bias,
eps=self.norm_f.eps,
residual=residual,
prenorm=False,
residual_in_fp32=self.residual_in_fp32,
is_rms_norm=isinstance(self.norm_f, MambaRMSNorm)
)
return hidden_states
MambaMixelModel.forward = MambaMixelModel_forward
"""

View File

@@ -1,796 +0,0 @@
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/mixtral/modeling_mixtral.py
import math
import torch
import torch.nn.functional as F
from typing import Literal, overload, Optional, Tuple, List, Union
from transformers import MixtralModel, MixtralConfig
from transformers.models.mixtral.modeling_mixtral import load_balancing_loss_func, MixtralSparseMoeBlock, MixtralAttention, MixtralDecoderLayer, MixtralRMSNorm, repeat_kv
from transformers.modeling_outputs import BaseModelOutputWithPast, MoeModelOutputWithPast
from transformers.cache_utils import Cache, DynamicCache, StaticCache
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
from transformers.processing_utils import Unpack
from .attention import *
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
Args:
q (`torch.Tensor`): The query tensor.
k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
position_ids (`torch.Tensor`):
The position indices of the tokens corresponding to the query and key tensors. For example, this can be
used to pass offsetted position ids when working with a KV-cache.
unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
Returns:
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
"""
cos = cos[position_ids].unsqueeze(unsqueeze_dim)
sin = sin[position_ids].unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
# This is required because batch sizes > 1 throws errors
def MixtralSparseMoeBlock_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
""" """
batch_size, sequence_length, hidden_dim = hidden_states.shape
if self.training and self.jitter_noise > 0:
hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise)
#hidden_states = hidden_states.view(-1, hidden_dim)
hidden_states = hidden_states.reshape(-1, hidden_dim)
# router_logits: (batch * sequence_length, n_experts)
router_logits = self.gate(hidden_states)
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
# we cast back to the input dtype
routing_weights = routing_weights.to(hidden_states.dtype)
final_hidden_states = torch.zeros(
(batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
)
# One hot encode the selected experts to create an expert mask
# this will be used to easily index which expert is going to be sollicitated
expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
# Loop over all available experts in the model and perform the computation on each expert
for expert_idx in range(self.num_experts):
expert_layer = self.experts[expert_idx]
idx, top_x = torch.where(expert_mask[expert_idx])
# Index the correct hidden states and compute the expert hidden state for
# the current expert. We need to make sure to multiply the output hidden
# states by `routing_weights` on the corresponding tokens (top-1 and top-2)
current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
# However `index_add_` only support torch tensors for indexing so we'll use
# the `top_x` tensor here.
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
return final_hidden_states, router_logits
MixtralSparseMoeBlock.forward = MixtralSparseMoeBlock_forward
class MixtralAttention_Adapted(MixtralAttention):
def __init__(self, *args, **kwargs):
self.mode = kwargs.pop("mode", "sdpa")
if self.mode == "math":
self.mode = torch.nn.attention.SDPBackend.MATH
elif self.mode == "mem_efficient":
self.mode = torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION
elif self.mode == "flash_(sdpa)":
self.mode = torch.nn.attention.SDPBackend.FLASH_ATTENTION
elif self.mode == "cudnn":
self.mode = torch.nn.attention.SDPBackend.CUDNN_ATTENTION
super().__init__(*args, **kwargs)
# extracts inputs from a batch based on requested causality
def split_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
is_causal: Optional[list] = None,
target_causal_state: Optional[bool] = True,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
**kwargs,
):
indices = [ i for i, state in enumerate( is_causal ) if state == target_causal_state ]
# no matching inputs in batch
if not indices:
return indices, None, None, None
# entire batch is homogenous
if len( indices ) == hidden_states.shape[0]:
output_hidden_states, output_self_attn_weights, output_present_key_values = self.forward(
hidden_states=hidden_states,
attention_mask=attention_mask,
is_causal=target_causal_state,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=False,
cache_position=cache_position,
position_embeddings=position_embeddings,
**kwargs,
)
return indices, output_hidden_states, output_self_attn_weights, output_present_key_values
input_hidden_states = torch.stack( [ hidden_states[i] for i in indices ] )
input_attention_mask = torch.stack( [ attention_mask[i] for i in indices ] ) if attention_mask is not None else None
input_position_ids = torch.stack( [ position_ids[i] for i in indices ] ) if position_ids is not None else None
input_position_embeddings = (
torch.stack( [ position_embeddings[0][i] for i in indices ] ),
torch.stack( [ position_embeddings[1][i] for i in indices ] ),
) if position_embeddings is not None else None
output_hidden_states, output_self_attn_weights, output_present_key_values = self.forward(
hidden_states=input_hidden_states,
attention_mask=input_attention_mask,
is_causal=target_causal_state,
position_ids=input_position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=False,
cache_position=cache_position,
position_embeddings=input_position_embeddings,
**kwargs,
)
return indices, output_hidden_states, output_self_attn_weights, output_present_key_values
# Adapted from LlamaAttention.forward
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
is_causal: bool = True,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
mode = "default" if output_attentions else self.mode
non_split_attention = [
"default",
torch.nn.attention.SDPBackend.MATH,
torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION,
torch.nn.attention.SDPBackend.FLASH_ATTENTION,
torch.nn.attention.SDPBackend.CUDNN_ATTENTION
]
# split per batch because other attention mechanisms do not have a conditional is_causal per-batch, only for the entire input
if isinstance( is_causal, list ) and mode not in non_split_attention:
# initialize lists
attn_hidden_states = [ None for _ in is_causal ]
self_attn_weights = [ None for _ in is_causal ]
present_key_values = [ None for _ in is_causal ]
# process causal inputs in a batch
causal_indices, causal_hidden_states, causal_self_attn_weights, causal_present_key_values = self.split_forward(
hidden_states=hidden_states,
attention_mask=attention_mask,
is_causal=is_causal,
target_causal_state=True,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=False,
cache_position=cache_position,
position_embeddings=position_embeddings,
**kwargs,
)
# process non-causal inputs in a batch
non_causal_indices, non_causal_hidden_states, non_causal_self_attn_weights, non_causal_present_key_values = self.split_forward(
hidden_states=hidden_states,
attention_mask=attention_mask,
is_causal=is_causal,
target_causal_state=False,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=False,
cache_position=cache_position,
position_embeddings=position_embeddings,
**kwargs,
)
# insert causal outputs to batch
for i, idx in enumerate( causal_indices ):
attn_hidden_states[idx] = causal_hidden_states[i]
if output_attentions:
self_attn_weights[idx] = causal_self_attn_weights[i]
# insert non-causal outputs to batch
for i, idx in enumerate( non_causal_indices ):
attn_hidden_states[idx] = non_causal_hidden_states[i]
if output_attentions:
self_attn_weights[idx] = non_causal_self_attn_weights[i]
# combine list
attn_hidden_states = torch.stack( attn_hidden_states, dim=0 )
if output_attentions:
self_attn_weights = torch.stack( self_attn_weights, dim=0 )
return attn_hidden_states, output_attentions, []
dropout_rate = self.attention_dropout if self.training else 0.0
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
if self.layer_idx is None:
raise ValueError(
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"with a layer index."
)
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
attn_scores = None
if mode in ["xformers", "flash_attn"]:
# TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
# to be able to avoid many of these transpose/reshape/view.
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
"""
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
# therefore the input hidden states gets silently casted in float32. Hence, we need
# cast them back in the correct dtype just to be sure everything works as expected.
# This might slowdown training & inference so it is recommended to not cast the LayerNorms
# in fp32. (LlamaRMSNorm handles it correctly)
input_dtype = query_states.dtype
if input_dtype == torch.float32:
if torch.is_autocast_enabled():
target_dtype = torch.get_autocast_gpu_dtype()
# Handle the case where the model is quantized
elif hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = self.config._pre_quantization_dtype
else:
target_dtype = self.q_proj.weight.dtype
query_states = query_states.to(target_dtype)
key_states = key_states.to(target_dtype)
value_states = value_states.to(target_dtype)
"""
if mode == "flash_attn":
attn_output = flash_attn_func(
query_states,
key_states,
value_states,
causal=is_causal,
softmax_scale=1.0 / math.sqrt(self.head_dim),
dropout_p=dropout_rate,
)
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
elif mode == "xformers":
attn_output = memory_efficient_attention(
query_states,
key_states,
value_states,
attn_bias = LowerTriangularMask(),
scale = 1.0 / math.sqrt(self.head_dim),
p=dropout_rate
)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
return attn_output, attn_scores, past_key_value
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
x_mask = attention_mask
if attention_mask is not None:
x_mask = x_mask[:, :, :, : key_states.shape[-2]]
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
if query_states.device.type == "cuda" and x_mask is not None:
query_states = query_states.contiguous()
key_states = key_states.contiguous()
value_states = value_states.contiguous()
if mode in ["sageattn"]:
attn_output = sageattn(
query_states,
key_states,
value_states,
tensor_layout="HND",
is_causal=is_causal
)
elif mode in ["fused_attn"]:
attn_output = fused_attn_func(
query_states,
key_states,
value_states,
causal=is_causal,
softmax_scale=1.0 / math.sqrt(self.head_dim),
dropout_p=dropout_rate,
)
elif mode in ["default"]:
attn_scores = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
# cringe logic
attn_weights = (attn_scores + x_mask) if attention_mask is not None else (attn_scores)
# upcast attention to fp32
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_weights = torch.nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
else:
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
# is_causal = True if x_mask is None and q_len > 1 else False
is_causal = True if x_mask is None and q_len > 1 else False
with torch.nn.attention.sdpa_kernel(self.mode):
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=x_mask,
dropout_p=dropout_rate,
is_causal=is_causal,
)
# cringe
if attn_scores is None and output_attentions:
attn_scores = attn_output
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(bsz, q_len, -1)
attn_output = self.o_proj(attn_output)
return attn_output, attn_scores, past_key_value
class MixtralDecoderLayer_Adapted(MixtralDecoderLayer):
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
is_causal: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
output_router_logits: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
**kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
`(batch, sequence_length)` where padding elements are indicated by 0.
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
output_router_logits (`bool`, *optional*):
Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
should not be returned during inference.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
(see `past_key_values`).
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence.
kwargs (`dict`, *optional*):
Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
into the model
"""
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
position_embeddings=position_embeddings,
attention_mask=attention_mask,
is_causal=is_causal,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
**kwargs,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states, router_logits = self.block_sparse_moe(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
if output_router_logits:
outputs += (router_logits,)
return outputs
class MixtralRotaryEmbedding(torch.nn.Module):
def __init__(self, config: MixtralConfig, device=None):
super().__init__()
# BC: "rope_type" was originally "type"
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
else:
self.rope_type = "default"
self.max_seq_len_cached = config.max_position_embeddings
self.original_max_seq_len = config.max_position_embeddings
self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq
def _dynamic_frequency_update(self, position_ids, device):
"""
dynamic RoPE layers should recompute `inv_freq` in the following situations:
1 - growing beyond the cached sequence length (allow scaling)
2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
"""
seq_len = torch.max(position_ids) + 1
if seq_len > self.max_seq_len_cached: # growth
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
self.max_seq_len_cached = seq_len
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
# This .to() is needed if the model has been moved to a device after being initialized (because
# the buffer is automatically moved, but not the original copy)
self.original_inv_freq = self.original_inv_freq.to(device)
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
self.max_seq_len_cached = self.original_max_seq_len
@torch.no_grad()
def forward(self, x, position_ids):
if "dynamic" in self.rope_type:
self._dynamic_frequency_update(position_ids, device=x.device)
# Core RoPE block
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
# Force float32 (see https://github.com/huggingface/transformers/pull/29285)
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
# Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
cos = cos * self.attention_scaling
sin = sin * self.attention_scaling
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
class MixtralModel_Adapted(MixtralModel):
def __init__(self, config: MixtralConfig):
#super().__init__(config)
super(MixtralModel, self).__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = torch.nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
self.layers = torch.nn.ModuleList(
[MixtralDecoderLayer_Adapted(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
self.norm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.rotary_emb = MixtralRotaryEmbedding(config)
self.gradient_checkpointing = False
# Initialize weights and apply final processing
self.post_init()
def _update_noncausal_mask(
self,
attention_mask,
inputs_embeds,
past_key_values_length,
):
# create noncausal mask
# [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
bsz, seq_len, _ = inputs_embeds.size()
# generate default mask based on input
if attention_mask is None:
attention_mask = torch.ones( (bsz, seq_len), dtype=torch.bool, device=inputs_embeds.device )
# make square
expanded_mask = attention_mask[:, None, None, :].expand( bsz, 1, seq_len, seq_len ).to( dtype=inputs_embeds.dtype )
# invert from 1.0 = attend, 0.0 = masked to 0.0 = valid, -inf = masked
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill( inverted_mask.to(dtype=torch.bool), torch.finfo(inputs_embeds.dtype).min )
# gut out the things that just shoves responsibility on SDPA's is_causal generating a mask because this causes problems
def _update_causal_mask(
self,
attention_mask: torch.Tensor,
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_key_values: Cache,
output_attentions: bool,
):
"""
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
"""
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
using_static_cache = isinstance(past_key_values, StaticCache)
"""
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
# to infer the attention mask.
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
if AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask,
inputs_embeds=input_tensor,
past_key_values_length=past_seen_tokens,
is_training=self.training,
):
return None
"""
dtype, device = input_tensor.dtype, input_tensor.device
sequence_length = input_tensor.shape[1]
if using_static_cache:
target_length = past_key_values.get_max_cache_shape()
else:
target_length = (
attention_mask.shape[-1]
if isinstance(attention_mask, torch.Tensor)
else past_seen_tokens + sequence_length + 1
)
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
device=device,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
config=self.config,
past_key_values=past_key_values,
)
if (
self.config._attn_implementation == "sdpa"
and attention_mask is not None
and attention_mask.device.type == "cuda"
and not output_attentions
):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
min_dtype = torch.finfo(dtype).min
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
return causal_mask
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
is_causal: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
output_router_logits: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_router_logits = (
output_router_logits if output_router_logits is not None else self.config.output_router_logits
)
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if self.gradient_checkpointing and self.training:
if use_cache:
use_cache = False
if use_cache and past_key_values is None:
past_key_values = DynamicCache()
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
#causal_mask = self._update_causal_mask(
# attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
#)
# because a batch can mix causal and non-causal sequences, generate both masks then pick which one to use per batch entry
if is_causal is not None:
"""
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=inputs_embeds.shape[1],
target_length=attention_mask.shape[-1] if attention_mask is not None else inputs_embeds.shape[1],
dtype=inputs_embeds.dtype,
device=inputs_embeds.device,
cache_position=cache_position,
batch_size=inputs_embeds.shape[0],
)
"""
causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions)
noncausal_mask = self._update_noncausal_mask(attention_mask, inputs_embeds, past_key_values)
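# e.g. is_causal = [True, False] picks the causal (lower-triangular) mask for batch 0 and the bidirectional mask for batch 1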
x_mask = torch.stack( [ causal_mask[i, :, :, :] if state else noncausal_mask[i, :, :, :] for i, state in enumerate( is_causal ) ], dim=0 )
else:
x_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions)
hidden_states = inputs_embeds
# create position embeddings to be shared across the decoder layers
position_embeddings = self.rotary_emb(hidden_states, position_ids)
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
all_router_logits = () if output_router_logits else None
for decoder_layer in self.layers:
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
x_mask,
is_causal,
position_ids,
past_key_values,
output_attentions,
output_router_logits,
use_cache,
cache_position,
position_embeddings,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=x_mask,
is_causal=is_causal,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
output_router_logits=output_router_logits,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**flash_attn_kwargs,
)
hidden_states = layer_outputs[0]
if output_attentions:
all_self_attns += (layer_outputs[1],)
if output_router_logits:
all_router_logits += (layer_outputs[-1],)
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
output = MoeModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=past_key_values,
hidden_states=all_hidden_states,
attentions=all_self_attns,
router_logits=all_router_logits,
)
return output if return_dict else output.to_tuple()

View File

@ -1,46 +0,0 @@
# https://github.com/microsoft/torchscale
from torchscale.architecture.config import RetNetConfig
from torchscale.architecture.retnet import RetNetDecoder
# from retnet import RetNet
# override MultiScaleRetention's forward because training with te throws an error
from torchscale.component.multiscale_retention import MultiScaleRetention, theta_shift
def MultiScaleRetention_forward(
self,
x,
rel_pos,
chunkwise_recurrent=False,
incremental_state=None
):
bsz, tgt_len, _ = x.size()
(sin, cos), inner_mask = rel_pos
q = self.q_proj(x)
k = self.k_proj(x) * self.scaling
v = self.v_proj(x)
g = self.g_proj(x)
q = q.view(bsz, tgt_len, self.num_heads, self.key_dim).transpose(1, 2)
k = k.view(bsz, tgt_len, self.num_heads, self.key_dim).transpose(1, 2)
qr = theta_shift(q, sin, cos)
kr = theta_shift(k, sin, cos)
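# three execution paths: recurrent (incremental decoding), chunkwise recurrent, and fully parallel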
if incremental_state is not None:
output = self.recurrent_forward(qr, kr, v, inner_mask, incremental_state)
elif chunkwise_recurrent:
output = self.chunk_recurrent_forward(qr, kr, v, inner_mask)
else:
output = self.parallel_forward(qr, kr, v, inner_mask)
output = self.group_norm(output).reshape(bsz, tgt_len, self.head_dim * self.num_heads)
output = self.gate_fn(g) * output
output = self.out_proj(output)
return output
MultiScaleRetention.forward = MultiScaleRetention_forward

View File

@ -1,217 +0,0 @@
"""
# https://github.com/enhuiz/vall-e/
"""
import math
import torch
import torch.nn.functional as F
import traceback
from typing import Literal, overload
from functools import partial
from einops import rearrange
from torch import Tensor, einsum, nn
from torch.utils.checkpoint import checkpoint
from ...utils import ml
class AdaLN(nn.Module):
def __init__(self, d_model, n_levels, eps=1e-5, k=0.1, c=2):
super().__init__()
self.eps = eps
self.emb = nn.Embedding(n_levels, d_model * 2)
self.k = k
self.c = c
nn.init.zeros_(self.emb.weight)
def forward(self, x, l):
h = F.layer_norm(x, x.shape[-1:], eps=self.eps)
# The initial implementation (https://github.com/enhuiz/vall-e/blob/fbf023448c08e55c0422eefed7fc234cf8b76680/vall_e/vall_e/base.py#L135)
# performed worse than vanilla LayerNorm.
# The authors mentioned another AdaNorm paper (https://openreview.net/pdf?id=HyxndNrxLB) as they introduce AdaLN.
# Did they use AdaNorm inside AdaLN? (as follows)
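# i.e. AdaNorm: h <- c * (1 - k * h) * h, with the inner (k * h) term detached from the gradient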
h = self.c * (1 - (self.k * h).detach()) * h
logγ, β = self.emb(l).unsqueeze(1).chunk(2, dim=-1)
y = logγ.exp() * h + β
return y
class SinusoidalEmbedding(nn.Module):
def __init__(self, d_model):
super().__init__()
self.d_model = d_model
exponent = torch.arange(self.d_half, dtype=torch.float32)
exponent = exponent / self.d_half
omega = torch.exp(-math.log(1e4) * exponent)
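# omega[i] = 1 / 10000^(i / d_half), i.e. the standard transformer sinusoidal frequencies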
self.omega: torch.Tensor
self.register_buffer("omega", omega, persistent=False)
@property
def d_half(self):
assert self.d_model % 2 == 0, "Only support even d_model."
return self.d_model // 2
def forward(self, x):
"""
Args:
x: (...)
Returns:
pe: (... d)
"""
omega = self.omega
while omega.dim() <= x.dim():
omega = omega.unsqueeze(0) # (... d)
x = x.unsqueeze(-1) # (... 1)
x = omega * x
x = torch.cat([x.sin(), x.cos()], dim=-1)
return x
def get_pe(self, n: int):
"""
Args:
n: int
Returns:
pe: (n d)
"""
device = self.omega.device
return self.forward(torch.arange(n, device=device))
def add_pe(self, x):
"""
Args:
x: (b t c)
"""
e = self.get_pe(x.shape[1]) # t d
e = e[None] # b t d
x = x + e
return x
class Attention(nn.Module):
def __init__(self, d_model, n_heads, causal):
super().__init__()
assert d_model % n_heads == 0
dim_head = d_model // n_heads
self.causal = causal
self.n_heads = n_heads
self.scale = dim_head**-0.5
self.to_qkv = ml.Linear(d_model, d_model * 3, bias=False)
self.to_out = ml.Linear(d_model, d_model)
def forward(self, x, m):
"""
Args:
x: (b t c)
m: (b t c), 1 is data, 0 is padding
Returns:
x: (b t c)
"""
h = self.n_heads
q, k, v = self.to_qkv(x).chunk(3, dim=-1)
q, k, v = map(lambda t: rearrange(t, "b t (h d) -> b t h d", h=h), (q, k, v))
e = einsum("b i h d, b j h d -> b i j h", q, k)
e = e * self.scale
kpm = m.unsqueeze(1) * m.unsqueeze(2) # b i j 1
if self.causal:
with ml.autocast(kpm, torch.bfloat16, torch.float16) as kpm_cast: # avoid shadowing the key tensor `k`
kpm = kpm_cast.squeeze(-1).tril().unsqueeze(-1) # b i j 1
e = e.masked_fill(kpm == 0, -torch.finfo(e.dtype).max)
a = e.softmax(dim=2) # Normalize on j, i.e. key
o = einsum("b i j h, b j h d -> b i h d", a, v)
o = o.flatten(-2)
o = self.to_out(o) # b t c
o = o * m
return o
class PrenormResidual(nn.Module):
def __init__(
self,
block,
d_model,
p_dropout,
requires_mask=False,
norm_type="ln",
n_levels: int | None = None,
):
super().__init__()
self.block = block
self.requires_mask = requires_mask
self.norm_type = norm_type
if norm_type == "ln":
self.norm = nn.LayerNorm(d_model)
elif norm_type == "adaln":
assert n_levels is not None
self.norm = AdaLN(d_model, n_levels)
else:
raise NotImplementedError(norm_type)
self.dropout = nn.Dropout(p_dropout)
def forward(self, x, m, l):
"""
Args:
x: input (b t d)
m: mask (b t 1), 1 is valuable and 0 is padding
l: level to use, required only for AdaLN
"""
nopts = {"l": l} if self.norm_type == "adaln" else {}
bopts = {"m": m} if self.requires_mask else {}
x = x + self.dropout(self.block(self.norm(x, **nopts) * m, **bopts))
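# re-apply the mask so padded positions stay zeroed after the residual add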
return x * m
class Block(nn.Sequential):
def __init__(self, d_model, n_heads, p_dropout, causal, norm_type, n_levels, activation_checkpointing=True):
super().__init__()
self.activation_checkpointing = activation_checkpointing
self.attn = PrenormResidual(
Attention(d_model, n_heads, causal),
d_model=d_model,
p_dropout=p_dropout,
requires_mask=True,
norm_type=norm_type,
n_levels=n_levels,
)
n_ff = d_model * 4 # feed-forward hidden size (e.g. 1024 * 4 = 4096)
self.ffn = PrenormResidual(
nn.Sequential(
ml.Linear(d_model, n_ff),
nn.GELU(),
nn.Dropout(p_dropout),
ml.Linear(n_ff, d_model),
),
d_model=d_model,
p_dropout=p_dropout,
norm_type=norm_type,
n_levels=n_levels,
)
def forward(self, x, m, l):
"""
Args:
x: (b t c)
m: (b t 1)
l: (b)
"""
if x.requires_grad and self.activation_checkpointing:
x = checkpoint(self.attn, x, m, l, use_reentrant=False)
else:
x = self.attn(x, m, l)
x = self.ffn(x, m, l)
return x

View File

@ -47,10 +47,6 @@ Logits = namedtuple('Logits', ['logits', 'state', 'inputs', 'loss', 'attentions'
Sampled = namedtuple('Sampled', ['ids', 'logits', 'scores', 'entropy'])
LossStats = namedtuple('LossStats', ['loss', 'stats'])
"""
from ..utils.pattern import DelayedPatternProvider, VALLEPattern
"""
summed_embeddings_task = [ "stt" ]
special_tasks = [ "len", "stt", "phn", "text", "un-phn" ]
non_tokened_names = ["task", "dropout_mask", "classifier_level"]
@ -128,73 +124,6 @@ def _interleave_sequence_reshape( input: list[torch.Tensor], dim=-1 ):
def _interleave_sequence_flatten( input: list[torch.Tensor] ):
return torch.concat( [ i.t() for i in input ] ).t().flatten()
# Deprecated implementation
class MultiEmbedding(nn.Module):
def __init__(self, max_n_levels, n_tokens, token_dim, monolithic=False):
super().__init__()
self.monolithic = monolithic
self.max_n_levels = max_n_levels
self.n_tokens = n_tokens
self.weight = nn.Parameter(torch.randn(max_n_levels, n_tokens, token_dim))
# to-do: select the quant level from the quant_levels tensor if provided (i.e. through the resps_emb)
# I imagine this is an oversight in the NAR.
def forward(self, x_list: list[Tensor], quant_level: int | list[int] | Tensor | None = None) -> list[Tensor]:
if len(x_list) == 0:
return []
# this "strategy" will reserve the weight[0] for te AR and weight[1:] for the NAR
# the NAR cannot share RVQ-bin level 0 with the AR for the resps_emb
if self.monolithic:
w = self.weight[:1] if quant_level is None or quant_level == 0 else self.weight[1:]
else:
w = self.weight
padded_x_list = []
for i, xi in enumerate(x_list):
xi = F.one_hot(xi.to(torch.int64), num_classes=self.n_tokens) # t l' k
wi = w.shape[0] - xi.shape[1]
xi = F.pad(xi, (0, 0, 0, wi)) # t l k
padded_x_list.append(xi.to(w))
x = torch.cat(padded_x_list) # n l k
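# collapse the one-hot selections against each level's embedding weights into one d-dim vector per token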
x = einsum("l k d, n l k -> n d", w, x)
x_list = x.split([*map(len, x_list)])
return x_list
# Embedding that sums each RVQ-bin level within a given input acoustic prompt
# _Old, to preserve compat with previous models.
class AudioEmbedding_Old(nn.Module):
def __init__(
self,
l_embedding_tokens: list[int], # number of tokens per RVQ level (needed because AR resps includes the stop token)
token_dim: int, # dimensionality of the embedding
levels: int | None = None, # number of RVQ levels, used to size the per-level weight list
):
super().__init__()
# array of embeddings
# proms are [0, resp_levels]
# resp are split to where [0] is for the AR, and [1:] are reserved for NAR
self.embeddings = nn.ModuleList([ml.Embedding(n_tokens, token_dim) for n_tokens in l_embedding_tokens])
# per-level scaling weight (arguably redundant, since the embedding weights themselves should already account for this)
self.weight = nn.ParameterList([nn.Parameter( torch.tensor([1]) ) for i in range(levels)]) if levels is not None else None
def forward(self, xi: Tensor, quant_level: Tensor | None = None ) -> Tensor:
# prom
if quant_level is None and xi.shape[-1] > 1:
x = sum( [ self.embeddings[k]( xi[:, k] ) * (self.weight[k] if self.weight is not None else 1) for k in range(xi.shape[-1]) ] )
# prom / AR resp
elif quant_level is None or quant_level == 0:
x = self.embeddings[0]( xi if xi.dim() == 1 else xi[:, 0] )
# NAR resp
else:
x = sum( [ self.embeddings[k+1]( xi[:, k] ) * (self.weight[k+1] if self.weight is not None else 1) for k in range(xi.shape[-1]) ] )
return x
# Embedding that sums each RVQ-bin level within a given input acoustic prompt
# Mostly to handle some oversights and errors during testing
class AudioEmbedding(nn.Module):
@ -235,27 +164,6 @@ class AudioEmbedding(nn.Module):
return x
# time-step embedding
# for the NAR-len, since it most likely requires encoding the timestep
class TimeEmbedding(nn.Module):
def __init__(
self,
d_model
):
super().__init__()
self.emb = SinusoidalEmbedding(d_model)
self.mlp = nn.Sequential(
nn.Linear(d_model, d_model*4),
nn.SiLU(),
nn.Linear(d_model*4, d_model),
)
def forward( self, t ):
t = self.emb(t)
t = self.mlp(t)
return t
# per-level classification
# it might actually be "better" in the long run to only have one output head like a traditional LM, and de-stitch it here instead of the modulus math the HF/experimental impl does
class Classifiers(nn.Module):
@ -380,38 +288,6 @@ class Base(nn.Module):
return l[: indices.min().item()]
# these probably need to live in an interleaved model, as pattern-ing is targeted for a sole AR model
"""
def codes_to_pattern(self, codes):
# expand if not batched
if codes.dim() == 2:
codes = codes.unsqueeze(0)
# [batch, timestep, rvq level] (B, T, K) => [batch, rvq level, timestep] (B, K, T)
codes = codes.permute(0, 2, 1)
B, K, T = codes.shape
# map codes [B, K, T] into pattern sequence [B, K, S] using special_token_id for masked tokens
pattern = self.pattern_provider.get_pattern(T)
sequence_codes, sequence_indexes, sequence_mask = pattern.build_pattern_sequence(
codes.contiguous(), self.stop_token, keep_only_valid_steps=False,
)
# (B, K, T) => (B, T, K)
return sequence_codes.permute(0, 2, 1)
def logits_from_pattern(self, logits, pattern):
logits = logits.permute(0, 3, 1, 2) # [B, card, K, S]
logits, logits_indexes, logits_mask = pattern.revert_pattern_logits(
logits, float('nan'), keep_only_valid_steps=False
)
logits = logits.permute(0, 2, 3, 1) # [B, K, T, card]
logits_mask = logits_mask[None, :, :].expand(B, -1, -1) # [K, T] -> [B, K, T]
return logits, logits_mask
"""
def __init__(
self,
@ -453,13 +329,13 @@ class Base(nn.Module):
self.n_resp_levels = self.config.resp_levels if self.config else n_resp_levels
self.n_max_levels = self.config.max_levels if self.config else n_resp_levels
self.capabilities = self.config.capabilities if self.config else ["ar", "nar"]
self.capabilities = self.config.capabilities if self.config else ["ar", "nar", "len"]
self.gradient_checkpointing = self.config.gradient_checkpointing if self.config is not None else True
self.stop_token = self.n_audio_tokens
self.mask_token = self.stop_token
self.causal = "ar" in self.capabilities or "len" in self.capabilities
self.version = self.config.version if self.config is not None else 5
self.causal = True
self.version = self.config.version if self.config is not None else 6
self.causal_size = self.config.experimental.causal_size if self.config is not None else (1 if self.causal else 0)
self.arch_type = self.config.arch_type if self.config is not None else "llama"
@ -500,32 +376,11 @@ class Base(nn.Module):
n_tasks = self.config.tasks if self.config is not None else 8
n_langs = self.config.langs if self.config is not None else 2
n_tones = self.config.tones if self.config is not None else 1
# pure AR
if "nar" not in self.capabilities:
n_resp_tokens = n_audio_tokens + 1
l_embedding_tokens = [n_resp_tokens] * self.n_resp_levels
l_embedding_names = [f'AR:{i}:{i}' for i in range( self.n_resp_levels )]
l_classifier_tokens = [n_resp_tokens] * self.n_resp_levels
# NAR-len model
elif "len" in self.capabilities:
# +1 to include the stop or mask token
n_resp_tokens = n_audio_tokens + ( 1 if self.causal_size > 0 else 0 )
if "ar" in self.capabilities:
l_embedding_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1) + [n_resp_tokens]
l_classifier_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1) + [n_resp_tokens - 1]
l_embedding_names = ['AR:0:0'] + [f'NAR:{i}:{i+1}' for i in range( self.n_resp_levels - 1 )] + ['NAR:0:0']
else:
l_embedding_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
l_classifier_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
l_embedding_names = ['NAR:0:0'] + [f'NAR:{i}:{i+1}' for i in range( self.n_resp_levels - 1 )]
# AR+NAR model
else:
# +1 to include the stop or mask token
n_resp_tokens = n_audio_tokens + ( 1 if self.causal_size > 0 else 0 )
l_embedding_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
l_embedding_names = ['AR:0:0'] + [f'NAR:{i}:{i+1}' for i in range( self.n_resp_levels - 1 )]
l_classifier_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
n_resp_tokens = n_audio_tokens + ( 1 if self.causal_size > 0 else 0 )
l_embedding_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1) + [n_resp_tokens]
l_classifier_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1) + [n_resp_tokens - 1]
l_embedding_names = ['AR:0:0'] + [f'NAR:{i}:{i+1}' for i in range( self.n_resp_levels - 1 )] + ['NAR:0:0']
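# layout: AR at RVQ level 0 (with stop token), NAR for levels 1..N-1, and a trailing NAR:0:0 entry for the NAR-len's masked level 0 (with mask token)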
n_vocab = 17702 if not split_classifiers else n_resp_tokens + 1
@ -564,65 +419,28 @@ class Base(nn.Module):
self.sep = nn.Parameter(torch.randn(d_model))
self.dropout_token = nn.Parameter(torch.randn(d_model))
if self.version == 1: # legacy
n_audio_tokens += (n_tasks - 1) # old models have the task tokens in the prom
self.proms_emb = MultiEmbedding(self.n_resp_levels, n_audio_tokens, d_model)
self.resps_emb = MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model, monolithic=self.monolithic)
self.audio_emb = None
elif self.version < 5:
# [1024] * 8
self.proms_emb = AudioEmbedding_Old(
[n_audio_tokens] * self.n_resp_levels, d_model,
levels=self.n_resp_levels if self.version > 3 else None,
)
# [1024 + STOP] + [1024] * 8
self.resps_emb = AudioEmbedding_Old(
l_embedding_tokens, d_model,
levels=self.n_resp_levels if self.version > 3 else None,
)
self.audio_emb = None
else:
self.proms_emb = AudioEmbedding(
[n_audio_tokens] * self.n_resp_levels, d_model,
sums=audio_embedding_sums == "prom" or audio_embedding_sums == True,
)
self.resps_emb = AudioEmbedding(
l_embedding_tokens, d_model,
sums=audio_embedding_sums == "resp" or audio_embedding_sums == True,
l_embedding_names=l_embedding_names,
)
self.audio_emb = None
self.proms_emb = AudioEmbedding(
[n_audio_tokens] * self.n_resp_levels, d_model,
sums=audio_embedding_sums == "prom" or audio_embedding_sums == True,
)
self.resps_emb = AudioEmbedding(
l_embedding_tokens, d_model,
sums=audio_embedding_sums == "resp" or audio_embedding_sums == True,
l_embedding_names=l_embedding_names,
)
if self.version >= 3:
self.langs_emb = Embedding(n_langs, d_model) if n_langs > 0 else None
self.tasks_emb = Embedding(n_tasks, d_model) if n_tasks > 0 else None
self.capabilities += ["lang"]
self.langs_emb = Embedding(n_langs, d_model) if n_langs > 0 else None
self.tasks_emb = Embedding(n_tasks, d_model) if n_tasks > 0 else None
self.capabilities += ["lang"]
# never actually got added... I kept forgetting to classify all my audio for speaker's tone
if self.version >= 4:
self.tones_emb = Embedding(n_tones, d_model) if n_tones > 0 else None
self.tones_emb = Embedding(n_tones, d_model) if n_tones > 0 else None
# mamba requires this if a model does both AR and NAR tasks
# this *might* help for AR and NAR tasks since we explicitly specify the current RVQ level for a sequence, rather than having it "encoded" in the embeddings
# this ***might*** let me also unify the proms_emb and resps_embedding
if self.version >= 5:
# "len" RVQ level-0 gets an additional token
if self.version < 7:
self.rvq_l_emb = Embedding(self.n_resp_levels, d_model)
# experimental NAR-only mode
self.len_emb = Embedding(11, d_model)
if self.version >= 6:
self.raw_text_emb = Embedding(self.n_text_tokens, d_model)
self.rvq_l_emb = Embedding(self.n_resp_levels, d_model)
self.len_emb = Embedding(11, d_model)
self.raw_text_emb = Embedding(self.n_text_tokens, d_model)
if attention_backend == "auto":
attention_backend = "sdpa"
"""
if AVAILABLE_ATTENTIONS:
attention_backend = AVAILABLE_ATTENTIONS[0]
else:
attention_backend = "default"
"""
hf_attention = attention_backend
HF_ATTENTIONS = ["eager", "sdpa", "flash_attention_2"]
@ -638,132 +456,21 @@ class Base(nn.Module):
elif attention_backend == "fused_attn":
self.l_padding = 128
if self.arch_type == "transformer":
self.sin_emb = SinusoidalEmbedding(d_model)
self.blocks = nn.ModuleList([TransformerBlock(
d_model=d_model,
n_heads=n_heads,
p_dropout=p_dropout if training else 0.0,
causal=self.causal,
norm_type="ln", # adaln
n_levels=self.n_resp_levels,
) for _ in range(n_layers) ])
elif self.arch_type in ["llama", "mistral", "mixtral"]:
if n_experts <= 1:
self.model = LlamaModel_Adapted(LlamaConfig(
vocab_size=n_vocab,
hidden_size=d_model,
max_position_embeddings=max_position_embeddings,
intermediate_size=d_model*d_ffn,
num_hidden_layers=n_layers,
num_attention_heads=n_heads,
attention_dropout=p_dropout if training else 0.0,
num_key_value_heads=n_heads,
#sliding_window=75 * 12, # 12 second context window
hidden_act="gelu",
is_encoder_decoder=False,
is_decoder=True,
attn_implementation=hf_attention,
#gradient_checkpointing=self.gradient_checkpointing,
))
self.model = ml.replace_attention( self.model, klass=LlamaAttention_Adapted, target=LlamaAttention, mode=attention_backend )
"""
# replace with desired attention
if attention_backend not in HF_ATTENTIONS:
self.model = ml.replace_attention( self.model, klass=LlamaAttention_Adapted, target=LlamaAttention, mode=attention_backend )
"""
else:
self.model = MixtralModel_Adapted(MixtralConfig(
vocab_size =n_resp_tokens,
hidden_size=d_model,
max_position_embeddings=max_position_embeddings,
intermediate_size=d_model*4,
num_hidden_layers=n_layers,
num_attention_heads=n_heads,
attention_dropout=p_dropout if training else 0.0,
num_key_value_heads=n_heads,
#sliding_window=75 * 12, # 12 second context window
output_router_logits=training,
hidden_act="gelu",
is_encoder_decoder=False,
is_decoder=True,
num_local_experts=n_experts,
num_experts_per_tok=min(2, n_experts),
attn_implementation=hf_attention,
#gradient_checkpointing=self.gradient_checkpointing,
))
self.model = ml.replace_attention( self.model, klass=MixtralAttention_Adapted, target=MixtralAttention, mode=attention_backend )
"""
if attention_backend not in HF_ATTENTIONS:
self.model = ml.replace_attention( self.model, klass=MixtralAttention_Adapted, target=MixtralAttention, mode=attention_backend )
"""
if self.gradient_checkpointing and not self.model.gradient_checkpointing:
self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
use_reentrant=False
))
elif self.arch_type == "retnet":
kwargs = dict(
vocab_size=n_vocab,
decoder_embed_dim=d_model,
decoder_value_embed_dim =d_model * 2,
decoder_retention_heads=n_heads,
decoder_ffn_embed_dim=d_model * 4,
decoder_layers=n_layers,
dropout=p_dropout if training else 0.0,
checkpoint_activations=self.gradient_checkpointing,
activation_fn="gelu",
use_layernorm=self.version < 3,
use_biases=self.version < 3,
use_glu=self.version >= 3,
chunkwise_recurrent=self.causal and self.causal_size > 0,
recurrent_chunkwise_size=self.causal_size if self.causal else 0,
no_output_layer=True,
decoder_normalize_before=True,
rotary_embedding_base=10000
)
if n_experts > 1:
kwargs.update(dict(
use_xmoe=True,
moe_freq=1,
moe_expert_count=n_experts,
moe_gating_use_fp32=False,
))
self.model = RetNetDecoder(RetNetConfig(**kwargs))
elif self.arch_type in ["mamba2"]:
self.model = Mamba2Model(Mamba2Config(
vocab_size=n_vocab,
hidden_size=d_model,
expand=2,
num_hidden_layers=n_layers*2,
residual_in_fp32=True,
))
if self.gradient_checkpointing and not self.model.gradient_checkpointing:
self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
use_reentrant=False
))
elif self.arch_type in ["mamba"]:
self.model = MambaModel(MambaConfig(
vocab_size=n_vocab,
hidden_size=d_model,
expand=2,
num_hidden_layers=n_layers*2,
residual_in_fp32=True,
))
if self.gradient_checkpointing and not self.model.gradient_checkpointing:
self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
use_reentrant=False
))
else:
raise RuntimeError(f'Unknown arch specified: {self.arch_type}')
if hasattr( self.model, "embeddings" ):
del self.model.embeddings
self.model = LlamaModel(LlamaConfig(
vocab_size=n_vocab,
hidden_size=d_model,
max_position_embeddings=max_position_embeddings,
intermediate_size=d_model*d_ffn,
num_hidden_layers=n_layers,
num_attention_heads=n_heads,
attention_dropout=p_dropout if training else 0.0,
num_key_value_heads=n_heads,
hidden_act="gelu",
is_encoder_decoder=False,
is_decoder=True,
attn_implementation=hf_attention,
#gradient_checkpointing=self.gradient_checkpointing,
))
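# the backbone is now always a llama model; a single monolithic output head follows unless split_classifiers is set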
if not split_classifiers:
self.classifier = nn.Linear(d_model, n_vocab, bias=classifiers_bias)
@ -794,77 +501,38 @@ class Base(nn.Module):
hidden_states = None
# HF transformer derived model
if self.arch_type in ["llama", "mistral", "mixtral"]:
kwargs = dict(
inputs_embeds=x,
attention_mask=m,
past_key_values=state,
position_ids=position_ids,
use_cache=False, # not self.training,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=True,
is_causal=is_causal,
)
kwargs = dict(
inputs_embeds=x,
attention_mask=m,
past_key_values=state,
position_ids=position_ids,
use_cache=False, # not self.training,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=True,
is_causal=is_causal,
)
if self.n_experts > 1 and self.training:
kwargs["output_router_logits"] = True
if self.n_experts > 1 and self.training:
kwargs["output_router_logits"] = True
output = self.model(**kwargs)
x = output["last_hidden_state"]
# to-do: figure out why KV caching doesn't work
#if not self.training:
if state is not None:
state = output["past_key_values"]
output = self.model(**kwargs)
x = output["last_hidden_state"]
# to-do: figure out why KV caching doesn't work
#if not self.training:
if state is not None:
state = output["past_key_values"]
if output_attentions:
attentions = output["attentions"]
if output_hidden_states:
hidden_states = output["hidden_states"]
if self.n_experts > 1 and self.training:
router_logits = output["router_logits"]
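# Mixtral's load-balancing (router) auxiliary loss, encouraging even expert utilization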
aux_loss = self.model.config.router_aux_loss_coef * load_balancing_loss_func( router_logits, self.model.config.num_local_experts, self.model.config.num_experts_per_tok, m )
elif self.arch_type == "transformer":
# ensures we specify a quant_level for the transformer implementation's AdaLN
l = torch.zeros((batch_size,), dtype=torch.int32) if quant_levels is None else quant_levels
l = l.to(device)
# inject position information
x = self.sin_emb.add_pe(x)
# pass our inputs through the transformer
for block in self.blocks:
x = block(x, m, l)
elif self.arch_type == "retnet":
# pass our inputs through the RetNet
x, _ = self.model(x, incremental_state=state, token_embeddings=x, features_only=True)
if _ is not None and "l_aux" in _ and self.n_experts > 1:
aux_loss = torch.sum(torch.stack([ t for t in _["l_aux"] if t is not None])) * 0.001
elif self.arch_type in ["mamba","mamba2"]:
kwargs = dict(
inputs_embeds=x,
attention_mask=m,
#cache_params=state,
use_cache=False, # not self.training,
#position_ids=position_ids,
#output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=True,
)
output = self.model(**kwargs)
x = output["last_hidden_state"]
if state is not None:
state = output["cache_params"]
if output_attentions:
attentions = output["attentions"]
if output_hidden_states:
hidden_states = output["hidden_states"]
if output_attentions:
attentions = output["attentions"]
if output_hidden_states:
hidden_states = output["hidden_states"]
if self.n_experts > 1 and self.training:
router_logits = output["router_logits"]
aux_loss = self.model.config.router_aux_loss_coef * load_balancing_loss_func( router_logits, self.model.config.num_local_experts, self.model.config.num_experts_per_tok, m )
# process it into a format that I like
if output_hidden_states:
@ -1210,20 +878,6 @@ class Base(nn.Module):
quant_level
)
else:
"""
offset = 0
if "nar" not in self.capabilities:
offset = 0
elif quant_level > 0:
offset = 1
embedding = self.resps_emb(
input if input.dim() == 1 or quant_level == 0 else input[:, :quant_level],
offset = offset,
quant_level = 0 if quant_level == 0 else quant_level - 1, # input is one below the target quant level
)
"""
embedding = self.resps_emb(
input if input.dim() == 1 or quant_level == 0 else input[:, :quant_level],
#offset = 0 if classifier_level.startswith("AR:") else 1,
@ -1477,6 +1131,9 @@ class Base(nn.Module):
if name == "resp":
name = f'{name}[{quant_level}]'
"""
nll, metrics = _calc_loss( logits[batch_index][start:end], token.long(), causal )
if nll is not None:
if f'{name}.nll' not in loss:
loss[f'{name}.nll'] = []
@ -1690,7 +1347,7 @@ class Base(nn.Module):
logits = [ logit[-self.causal_size:] for logit in logits ]
# (NAR) disable stop token
if quant_levels is not None and "ar" in self.capabilities:
if quant_levels is not None:
logits = [ ban_tokens(logit, tokens=[self.stop_token]) for logit, l in zip( logits, map(len, prev_list) ) ]
# (AR-len) disable extraneous tokens
"""