decoupled llama backend to avoid any funny changes from transformers, removed other backends since i dont think i'll ever bother using them
This commit is contained in:
parent
ceecac6ffe
commit
eff180248c
|
@ -51,6 +51,7 @@ The reference model (`ar+nar-llama-8`/`ar+nar-len-llama-8`):
|
|||
* [ ] train a serviceable model for 44KHz audio (instead of 24KHz)
|
||||
* [ ] well-integrated training through the Web UI (without the kludge from ai-voice-cloning)
|
||||
* [x] clean up the README, and document, document, document.
|
||||
* [ ] cleanup the documentation again, as most of it feels like schizorambling......
|
||||
* [x] extend to multiple languages ([VALL-E X](https://arxiv.org/abs/2303.03926)).
|
||||
- reference model is trained against English, Japanese, French, German, Korean, and Chinese (Mandarin?).
|
||||
- [x] improve multi-lingual support
|
||||
|
@ -78,6 +79,7 @@ The reference model (`ar+nar-llama-8`/`ar+nar-len-llama-8`):
|
|||
* these features are predicated on the model being trained for it
|
||||
* [ ] smarter/clever inferencing, such as:
|
||||
* [x] inference *all* codebooks in one pass, rather than each level being its own discrete pass.
|
||||
* `cfg.model.version >= 7` models will rely on this
|
||||
* these features are predicated on the model being trained for it
|
||||
* [x] "rolling" context, where the last generated sentence is the prefix for the next sentence.
|
||||
* [ ] for the AR, stop inferencing sequences in the batch that has already hit its stop token
|
||||
|
@ -86,6 +88,9 @@ The reference model (`ar+nar-llama-8`/`ar+nar-len-llama-8`):
|
|||
* [x] SIM-O requires passing the raw waveform through a speaker-similarity model
|
||||
* [x] valle.cpp through llama.cpp + encodec.cpp
|
||||
* extend to decode with vocos.cpp, instead, for a quality improvement
|
||||
* [ ] 44KHz audio, through either DAC or `nvidia/audio-codec-44khz`
|
||||
* the former has quality issues in the higher RVQ levels, but may be resolved with the experimental implementation
|
||||
* the latter needs testing, as it being an FSQ codec requires extra care
|
||||
|
||||
## "Postmortem"
|
||||
|
||||
|
|
|
@ -9,10 +9,11 @@ The beauty of a transformer, I feel, is that you can easily define any task at i
|
|||
|
||||
The inputs are automatically sequenced in a way that a given task requires, and the outputs are handled as per the class that extends the base model.
|
||||
|
||||
While the original paper called for a separate AR model and a NAR model, and by treating the AR and the NAR as unique tasks, you can actually train a unified model (`AR+NAR`) for effectively free, as the internal states of the two should overlap quite a lot.
|
||||
While the original paper called for a separate AR model and a NAR model, by treating the AR and the NAR as unique tasks, you can actually train a unified model (`AR+NAR`) for effectively free, as the internal states of the two should overlap quite a lot.
|
||||
* Additionally, you can even train a `NAR-len` model on top of an existing model.
|
||||
|
||||
Later papers for discrete TTS solutions work around the multiple codebook problem by introducing exotic interleaving patterns to work around existing problems. For all intents and purposes, these aren't necessary, as the current sequencing of prioritizng the first codebook (RVQ level 0). The remaining RVQ levels can be easily deduced from the prior level in parallel.
|
||||
* Exotic solutions aren't necessary at all, as the summed embeddings can be good enough to represent the original waveform. Output codes can be inferenced in parallel with a wider head, neglecting the need to train separate levels.
|
||||
|
||||
## The AR (Autoregressive) Model
|
||||
|
||||
|
@ -41,8 +42,6 @@ Compared to non-autoregressive decoding, I personally feel that autoregressive e
|
|||
|
||||
Technically, with `cfg.model.version >= 7`, a model can be purely AR, as that version of the model encodes and decodes all codebooks of audio in a single pass.
|
||||
|
||||
Inferencing code is not available at the moment for this modality, but will be available in the future.
|
||||
|
||||
## The NAR (Non-autoregressive) Model
|
||||
|
||||
The NAR is responsible for generating the remaining RVQ levels of the audio codes for a given output. References to the "outputs from the NAR" refers to the underlying "levels" for a given waveform, as each further levels contributes to the final waveform less significantly than the previous.
|
||||
|
|
|
@ -39,6 +39,11 @@ With the other "hyperparameters" such as ratios for RVQ levels, tasks, etc:
|
|||
* it might be needed to later prefer a more balanced distribution (such as `[0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7]`) to get rid of any confidence issues in RVQ levels 1+, but I felt naively doing this harms the RVQ 0.
|
||||
* `prompt_similar_p` can be pretty much whatever > `0.5`. I've stuck with either `0.75` or `0.825` to prioritize adhering closely-er to the prompt, but still have random prompts used to help the model interanlly "model" what a speaker should sound like. In theory.
|
||||
|
||||
The optimizer used *mostly* doesn't matter, as AdamW seems to get moving faster, while Prodigyopt keeps things stable in the long run.
|
||||
* `APOLLO` needs more testing, but seemed adequate in cursory tests
|
||||
* `Muon` requires much more testing, but absolutely cannot be used for predicting tokens in place (NAR demasking), and requires `cfg.model.experimental.predict_causally=True`
|
||||
* I honestly don't think it gives good enough results from curosry tests for this application
|
||||
|
||||
## Try Me
|
||||
|
||||
To quickly test if a configuration works, you can run `python -m vall_e.models.ar_nar --yaml="./data/config.yaml"`; a small trainer will overfit a provided utterance.
|
||||
|
|
|
@ -1,6 +1,15 @@
|
|||
AVAILABLE_ARCHES = []
|
||||
ERROR_ARCHES = {}
|
||||
|
||||
try:
|
||||
from .llama import Config as LlamaConfig, Model as LlamaModel, Attention as LlamaAttention, AVAILABLE_ATTENTIONS
|
||||
AVAILABLE_ARCHES.append("llama")
|
||||
except Exception as e:
|
||||
ERROR_ARCHES["llama"] = e
|
||||
AVAILABLE_ATTENTIONS = []
|
||||
pass
|
||||
|
||||
"""
|
||||
try:
|
||||
from .transformer import SinusoidalEmbedding, Block as TransformerBlock
|
||||
AVAILABLE_ARCHES.append("transformer")
|
||||
|
@ -15,7 +24,6 @@ except Exception as e:
|
|||
ERROR_ARCHES["retnet"] = e
|
||||
pass
|
||||
|
||||
"""
|
||||
try:
|
||||
from .retnet_syncdoth.retnet_ts import RetNetDecoder as RetNetDecoder_TS, RetNetConfig as RetNetConfig_TS
|
||||
AVAILABLE_ARCHES.append("retnet-ts")
|
||||
|
@ -29,15 +37,6 @@ try:
|
|||
except Exception as e:
|
||||
ERROR_ARCHES["retnet-hf"] = e
|
||||
pass
|
||||
"""
|
||||
|
||||
try:
|
||||
from .llama import LlamaModel, LlamaModel_Adapted, LlamaConfig, AVAILABLE_ATTENTIONS, LlamaAttention, LlamaAttention_Adapted, LlamaDecoderLayer, LlamaDecoderLayer_Adapted, LlamaForCausalLM
|
||||
AVAILABLE_ARCHES.append("llama")
|
||||
except Exception as e:
|
||||
ERROR_ARCHES["llama"] = e
|
||||
AVAILABLE_ATTENTIONS = []
|
||||
pass
|
||||
|
||||
try:
|
||||
from .bitnet import BitNetTransformer
|
||||
|
@ -59,7 +58,7 @@ try:
|
|||
except Exception as e:
|
||||
ERROR_ARCHES["mamba"] = e
|
||||
ERROR_ARCHES["mamba2"] = e
|
||||
|
||||
"""
|
||||
"""
|
||||
try:
|
||||
from .mamba import MambaMixelModel, MambaLMHeadModel, MambaConfig
|
||||
|
|
|
@ -1,53 +0,0 @@
|
|||
# https://github.com/kyegomez/BitNet
|
||||
from torch import Tensor, nn
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
|
||||
from bitnet.bit_transformer import Transformer as BitNetTransformerBlock, RMSNorm as BitNetRMSNorm
|
||||
|
||||
# re-enable logging because zetascale fucking sucks
|
||||
import logging
|
||||
logging.getLogger().setLevel(logging.DEBUG)
|
||||
|
||||
# override for wrapping checkpointing
|
||||
def BitNetTransformerBlock_forward(self, x: Tensor, *args, **kwargs) -> Tensor:
|
||||
skip = x
|
||||
for attn, ffn in zip(self.layers, self.ffn_layers):
|
||||
if x.requires_grad and self.gradient_checkpointing:
|
||||
x, _ = checkpoint(attn, x, x, x, is_causal=True, *args, **kwargs, use_reentrant=False)
|
||||
else:
|
||||
x, _ = attn(x, x, x, is_causal=True, *args, **kwargs)
|
||||
x = x + skip
|
||||
x = ffn(x) + x
|
||||
return x
|
||||
|
||||
BitNetTransformerBlock.forward = BitNetTransformerBlock_forward
|
||||
|
||||
# override because bitnet's BitNetTransformer includes an embedding input / classifier output layers inside of it, which isn't favorable
|
||||
class BitNetTransformer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
depth: int,
|
||||
num_tokens: int,
|
||||
heads=8,
|
||||
ff_mult=4,
|
||||
gradient_checkpointing = True
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.transformer = BitNetTransformerBlock( dim=dim, depth=depth, heads=heads, ff_mult=ff_mult )
|
||||
self.norm = BitNetRMSNorm(dim)
|
||||
self.transformer.gradient_checkpointing = gradient_checkpointing
|
||||
|
||||
def forward(self, x):
|
||||
x = self.transformer(x)
|
||||
return self.norm( x )
|
||||
|
||||
"""
|
||||
from bitnet import BitNetTransformer
|
||||
def NoEmbedding_BitNetTransformer_Forward(self, x):
|
||||
x = self.transformer(x)
|
||||
return self.to_logits[0](x)
|
||||
|
||||
BitNetTransformer.forward = NoEmbedding_BitNetTransformer_Forward
|
||||
"""
|
|
@ -1,46 +1,157 @@
|
|||
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
|
||||
|
||||
import math
|
||||
import torch
|
||||
import logging
|
||||
import random
|
||||
|
||||
from typing import Literal, overload, Optional, Tuple, Union, List
|
||||
|
||||
from torch import Tensor, nn
|
||||
from transformers.cache_utils import Cache
|
||||
|
||||
from transformers import LlamaModel, LlamaConfig, LlamaForCausalLM
|
||||
# lazy
|
||||
from transformers.models.llama.configuration_llama import LlamaConfig as Config
|
||||
from transformers.models.llama.modeling_llama import LlamaPreTrainedModel
|
||||
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
|
||||
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPast
|
||||
from transformers.cache_utils import Cache, DynamicCache, StaticCache
|
||||
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaRMSNorm, LlamaRotaryEmbedding, apply_rotary_pos_emb, repeat_kv
|
||||
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
|
||||
from transformers.activations import ACT2FN
|
||||
|
||||
from .attention import *
|
||||
|
||||
LN_2 = 0.69314718056
|
||||
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
||||
|
||||
class LlamaAttention_Adapted(LlamaAttention):
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.mode = kwargs.pop("mode", "sdpa")
|
||||
if n_rep == 1:
|
||||
return hidden_states
|
||||
|
||||
if self.mode == "math":
|
||||
self.mode = torch.nn.attention.SDPBackend.MATH
|
||||
elif self.mode == "mem_efficient":
|
||||
self.mode = torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION
|
||||
elif self.mode == "flash_(sdpa)":
|
||||
self.mode = torch.nn.attention.SDPBackend.FLASH_ATTENTION
|
||||
elif self.mode == "cudnn":
|
||||
self.mode = torch.nn.attention.SDPBackend.CUDNN_ATTENTION
|
||||
elif self.mode == "sdpa":
|
||||
self.mode = torch.nn.attention.SDPBackend.MATH
|
||||
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
|
||||
|
||||
super().__init__(*args, **kwargs)
|
||||
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
||||
|
||||
if not hasattr(self, "num_heads"):
|
||||
self.num_heads = self.config.num_attention_heads
|
||||
if not hasattr(self, "num_key_value_heads"):
|
||||
self.num_key_value_heads = self.config.num_key_value_heads
|
||||
def rotate_half(x):
|
||||
x1 = x[..., : x.shape[-1] // 2]
|
||||
x2 = x[..., x.shape[-1] // 2 :]
|
||||
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
||||
cos = cos.unsqueeze(unsqueeze_dim)
|
||||
sin = sin.unsqueeze(unsqueeze_dim)
|
||||
q_embed = (q * cos) + (rotate_half(q) * sin)
|
||||
k_embed = (k * cos) + (rotate_half(k) * sin)
|
||||
|
||||
return q_embed, k_embed
|
||||
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, hidden_size, eps=1e-6):
|
||||
super().__init__()
|
||||
|
||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||
self.variance_epsilon = eps
|
||||
|
||||
def forward(self, hidden_states):
|
||||
input_dtype = hidden_states.dtype
|
||||
hidden_states = hidden_states.to(torch.float32)
|
||||
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
||||
|
||||
return self.weight * hidden_states.to(input_dtype)
|
||||
|
||||
def extra_repr(self):
|
||||
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
||||
|
||||
class RotaryEmbedding(nn.Module):
|
||||
def __init__(self, config, device=None):
|
||||
super().__init__()
|
||||
|
||||
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
|
||||
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
||||
else:
|
||||
self.rope_type = "default"
|
||||
|
||||
self.max_seq_len_cached = config.max_position_embeddings
|
||||
self.original_max_seq_len = config.max_position_embeddings
|
||||
|
||||
self.config = config
|
||||
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
||||
|
||||
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
self.original_inv_freq = self.inv_freq
|
||||
|
||||
def _dynamic_frequency_update(self, position_ids, device):
|
||||
seq_len = torch.max(position_ids) + 1
|
||||
if seq_len > self.max_seq_len_cached: # growth
|
||||
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
self.max_seq_len_cached = seq_len
|
||||
|
||||
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len:
|
||||
self.original_inv_freq = self.original_inv_freq.to(device)
|
||||
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
|
||||
self.max_seq_len_cached = self.original_max_seq_len
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, x, position_ids):
|
||||
if "dynamic" in self.rope_type:
|
||||
self._dynamic_frequency_update(position_ids, device=x.device)
|
||||
|
||||
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
|
||||
position_ids_expanded = position_ids[:, None, :].float()
|
||||
|
||||
device_type = x.device.type
|
||||
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
|
||||
with torch.autocast(device_type=device_type, enabled=False):
|
||||
freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
cos = emb.cos()
|
||||
sin = emb.sin()
|
||||
|
||||
cos = cos * self.attention_scaling
|
||||
sin = sin * self.attention_scaling
|
||||
|
||||
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, config, layer_idx, mode = "default"):
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
self.layer_idx = layer_idx
|
||||
self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
|
||||
self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
|
||||
self.scaling = self.head_dim**-0.5
|
||||
self.attention_dropout = config.attention_dropout
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.num_key_value_heads = config.num_key_value_heads
|
||||
|
||||
self.attn_mode = mode
|
||||
if self.attn_mode == "math":
|
||||
self.attn_mode = torch.nn.attention.SDPBackend.MATH
|
||||
elif self.attn_mode == "mem_efficient":
|
||||
self.attn_mode = torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION
|
||||
elif self.attn_mode == "flash_(sdpa)":
|
||||
self.attn_mode = torch.nn.attention.SDPBackend.FLASH_ATTENTION
|
||||
elif self.attn_mode == "cudnn":
|
||||
self.attn_mode = torch.nn.attention.SDPBackend.CUDNN_ATTENTION
|
||||
elif self.attn_mode == "sdpa":
|
||||
self.attn_mode = torch.nn.attention.SDPBackend.MATH
|
||||
|
||||
self.q_proj = nn.Linear(
|
||||
config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
|
||||
)
|
||||
self.k_proj = nn.Linear(
|
||||
config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
|
||||
)
|
||||
self.v_proj = nn.Linear(
|
||||
config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
|
||||
)
|
||||
self.o_proj = nn.Linear(
|
||||
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
|
||||
)
|
||||
|
||||
# extracts inputs from a batch based on requested causality
|
||||
def split_forward(
|
||||
|
@ -115,7 +226,7 @@ class LlamaAttention_Adapted(LlamaAttention):
|
|||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
mode = "default" if output_attentions else self.mode
|
||||
mode = "default" if output_attentions else self.attn_mode
|
||||
non_split_attention = [
|
||||
"default",
|
||||
torch.nn.attention.SDPBackend.MATH,
|
||||
|
@ -207,34 +318,10 @@ class LlamaAttention_Adapted(LlamaAttention):
|
|||
attn_scores = None
|
||||
|
||||
if mode in ["xformers", "flash_attn"]:
|
||||
# TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
|
||||
# to be able to avoid many of these transpose/reshape/view.
|
||||
query_states = query_states.transpose(1, 2)
|
||||
key_states = key_states.transpose(1, 2)
|
||||
value_states = value_states.transpose(1, 2)
|
||||
|
||||
"""
|
||||
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
|
||||
# therefore the input hidden states gets silently casted in float32. Hence, we need
|
||||
# cast them back in the correct dtype just to be sure everything works as expected.
|
||||
# This might slowdown training & inference so it is recommended to not cast the LayerNorms
|
||||
# in fp32. (LlamaRMSNorm handles it correctly)
|
||||
|
||||
input_dtype = query_states.dtype
|
||||
if input_dtype == torch.float32:
|
||||
if torch.is_autocast_enabled():
|
||||
target_dtype = torch.get_autocast_gpu_dtype()
|
||||
# Handle the case where the model is quantized
|
||||
elif hasattr(self.config, "_pre_quantization_dtype"):
|
||||
target_dtype = self.config._pre_quantization_dtype
|
||||
else:
|
||||
target_dtype = self.q_proj.weight.dtype
|
||||
|
||||
query_states = query_states.to(target_dtype)
|
||||
key_states = key_states.to(target_dtype)
|
||||
value_states = value_states.to(target_dtype)
|
||||
"""
|
||||
|
||||
if mode == "flash_attn":
|
||||
attn_output = flash_attn_func(
|
||||
query_states,
|
||||
|
@ -311,7 +398,7 @@ class LlamaAttention_Adapted(LlamaAttention):
|
|||
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
|
||||
# is_causal = True if x_mask is None and q_len > 1 else False
|
||||
is_causal = True if x_mask is None and q_len > 1 else False
|
||||
with torch.nn.attention.sdpa_kernel(self.mode):
|
||||
with torch.nn.attention.sdpa_kernel(self.attn_mode):
|
||||
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
||||
query_states,
|
||||
key_states,
|
||||
|
@ -332,28 +419,33 @@ class LlamaAttention_Adapted(LlamaAttention):
|
|||
|
||||
return attn_output, attn_scores, past_key_value
|
||||
|
||||
class LlamaDecoderLayer_Adapted(LlamaDecoderLayer):
|
||||
# apply timestep embedding with attention norm
|
||||
# I don't have a concrete idea on how helpful this is, as:
|
||||
# * F5-TTS's UNetT implementation doesn't do this
|
||||
# * F5-TTS's DiT does this, but only for pre-attention normalization
|
||||
# * MaskGCT does this for both
|
||||
# * Muse doesn't do this, but instead appends the timestep embedding
|
||||
def weigh_by_timestep(
|
||||
self,
|
||||
hidden_states,
|
||||
timesteps,
|
||||
):
|
||||
if timesteps is None:
|
||||
return hidden_states
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
|
||||
for i, timestep in enumerate( timesteps ):
|
||||
# invalid
|
||||
if not isinstance( timestep, torch.Tensor ):
|
||||
continue
|
||||
hidden_states[i] *= timestep
|
||||
|
||||
return hidden_states
|
||||
self.config = config
|
||||
self.hidden_size = config.hidden_size
|
||||
self.intermediate_size = config.intermediate_size
|
||||
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
|
||||
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
|
||||
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
|
||||
self.act_fn = ACT2FN[config.hidden_act]
|
||||
|
||||
def forward(self, x):
|
||||
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
||||
return down_proj
|
||||
|
||||
class DecoderLayer(nn.Module):
|
||||
def __init__(self, config, layer_idx):
|
||||
super().__init__()
|
||||
|
||||
self.hidden_size = config.hidden_size
|
||||
|
||||
self.self_attn = Attention(config=config, layer_idx=layer_idx)
|
||||
|
||||
self.mlp = MLP(config)
|
||||
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
@ -366,35 +458,11 @@ class LlamaDecoderLayer_Adapted(LlamaDecoderLayer):
|
|||
use_cache: Optional[bool] = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
|
||||
timesteps: Optional[list] = None,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
"""
|
||||
Args:
|
||||
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
||||
attention_mask (`torch.FloatTensor`, *optional*):
|
||||
attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
|
||||
query_sequence_length, key_sequence_length)` if default attention is used.
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
||||
returned tensors for more detail.
|
||||
use_cache (`bool`, *optional*):
|
||||
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
||||
(see `past_key_values`).
|
||||
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
||||
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
||||
Indices depicting the position of the input sequence tokens in the sequence
|
||||
position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
|
||||
Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
|
||||
with `head_dim` being the embedding dimension of each attention head.
|
||||
kwargs (`dict`, *optional*):
|
||||
Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
|
||||
into the model
|
||||
"""
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
hidden_states = self.weigh_by_timestep( hidden_states, timesteps )
|
||||
|
||||
# ugh
|
||||
if isinstance( is_causal, list ) and len(is_causal) == 1:
|
||||
|
@ -418,8 +486,6 @@ class LlamaDecoderLayer_Adapted(LlamaDecoderLayer):
|
|||
# Fully Connected
|
||||
residual = hidden_states
|
||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||
hidden_states = self.weigh_by_timestep( hidden_states, timesteps )
|
||||
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
|
@ -433,64 +499,24 @@ class LlamaDecoderLayer_Adapted(LlamaDecoderLayer):
|
|||
|
||||
return outputs
|
||||
|
||||
class LlamaModel_Adapted(LlamaModel):
|
||||
def __init__(self, config, *args, **kwargs):
|
||||
self.layer_dropout_p = kwargs.pop("layer_dropout_p", 0)
|
||||
self.early_exit_scale = kwargs.pop("early_exit_scale", 0.1)
|
||||
self.early_exit_r = kwargs.pop("early_exit_r", 2)
|
||||
|
||||
#super().__init__(*args, **kwargs)
|
||||
super(LlamaModel, self).__init__(config)
|
||||
class Model(LlamaPreTrainedModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.vocab_size = config.vocab_size
|
||||
self.layers_n = config.num_hidden_layers
|
||||
|
||||
# self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
||||
|
||||
self.layers = nn.ModuleList(
|
||||
[LlamaDecoderLayer_Adapted(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
||||
[DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
||||
)
|
||||
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.rotary_emb = LlamaRotaryEmbedding(config=config)
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.rotary_emb = RotaryEmbedding(config=config)
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def dropoff_layer( self, l ):
|
||||
if not self.training or self.layer_dropout_p <= 0:
|
||||
return False
|
||||
|
||||
# this could probably a LUT but I'm not fiending for aggressive mal-optimizations
|
||||
D = math.exp((l * LN_2) / (self.layers_n - 1)) - 1
|
||||
P = D * self.layer_dropout_p
|
||||
return random.random() < P
|
||||
|
||||
def cirriculum( self, l, t=None ):
|
||||
# no timestep data passed, just treat all layers as enabled
|
||||
# there doesn't seem /too/ bad of a performance hit, but the paper mentions it affecting accuracy of the last layer if all layers had early exit
|
||||
if t is None:
|
||||
return 1
|
||||
|
||||
# YUCK
|
||||
# this guarantees at least R layers are active at all intervals, which is important because this gives a division by zero otherwise
|
||||
for i in range(self.early_exit_r):
|
||||
if l == ((t % self.layers_n) + i * (self.layers_n // self.early_exit_r)) % self.layers_n:
|
||||
return 1
|
||||
return 0
|
||||
|
||||
def early_exit_loss( self, losses, t=None ):
|
||||
return sum([ self.normalized_per_layer_loss_scale( l, t ) * losses[l] for l in range(0, self.layers_n) ])
|
||||
|
||||
def normalized_per_layer_loss_scale( self, l, t=None ):
|
||||
return (self.cirriculum(l, t) * self.early_exit_factor( l )) / sum([ self.cirriculum(i, t) * self.early_exit_factor( i ) for i in range(0, self.layers_n) ])
|
||||
|
||||
def early_exit_factor( self, l ):
|
||||
if 0 <= l and l < self.layers_n:
|
||||
return self.early_exit_scale * sum([ i for i in range(0, l) ])
|
||||
return self.layers_n - 1 + self.early_exit_scale * sum([ i for i in range(0, self.layers_n - 1) ])
|
||||
|
||||
# shamelessly borrowed from https://github.com/open-mmlab/Amphion/blob/main/models/tts/maskgct/llama_nar.py#L256 until I replace it with my own noncausal-mask maker
|
||||
def _update_noncausal_mask(
|
||||
self,
|
||||
|
@ -514,6 +540,41 @@ class LlamaModel_Adapted(LlamaModel):
|
|||
inverted_mask = 1.0 - expanded_mask
|
||||
return inverted_mask.masked_fill( inverted_mask.to(dtype=torch.bool), torch.finfo(inputs_embeds.dtype).min )
|
||||
|
||||
@staticmethod
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
target_length: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
cache_position: torch.Tensor,
|
||||
batch_size: int,
|
||||
**kwargs,
|
||||
):
|
||||
if attention_mask is not None and attention_mask.dim() == 4:
|
||||
causal_mask = attention_mask
|
||||
else:
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
causal_mask = torch.full(
|
||||
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
|
||||
)
|
||||
if sequence_length != 1:
|
||||
causal_mask = torch.triu(causal_mask, diagonal=1)
|
||||
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
|
||||
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||
mask_length = attention_mask.shape[-1]
|
||||
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
|
||||
causal_mask.device
|
||||
)
|
||||
padding_mask = padding_mask == 0
|
||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||
padding_mask, min_dtype
|
||||
)
|
||||
|
||||
return causal_mask
|
||||
|
||||
# gut out the things that just shoves responsibility on SDPA's is_causal generating a mask because this causes problems
|
||||
def _update_causal_mask(
|
||||
self,
|
||||
|
@ -523,31 +584,9 @@ class LlamaModel_Adapted(LlamaModel):
|
|||
past_key_values: Cache,
|
||||
output_attentions: bool,
|
||||
):
|
||||
"""
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
if attention_mask is not None and 0.0 in attention_mask:
|
||||
return attention_mask
|
||||
return None
|
||||
"""
|
||||
|
||||
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||
using_static_cache = isinstance(past_key_values, StaticCache)
|
||||
|
||||
"""
|
||||
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
||||
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
||||
# to infer the attention mask.
|
||||
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
|
||||
if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
|
||||
if AttentionMaskConverter._ignore_causal_mask_sdpa(
|
||||
attention_mask,
|
||||
inputs_embeds=input_tensor,
|
||||
past_key_values_length=past_seen_tokens,
|
||||
is_training=self.training,
|
||||
):
|
||||
return None
|
||||
"""
|
||||
|
||||
dtype, device = input_tensor.dtype, input_tensor.device
|
||||
sequence_length = input_tensor.shape[1]
|
||||
if using_static_cache:
|
||||
|
@ -576,9 +615,6 @@ class LlamaModel_Adapted(LlamaModel):
|
|||
and attention_mask.device.type == "cuda"
|
||||
and not output_attentions
|
||||
):
|
||||
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
|
||||
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
|
||||
# Details: https://github.com/pytorch/pytorch/issues/110213
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
|
||||
|
||||
|
@ -596,10 +632,7 @@ class LlamaModel_Adapted(LlamaModel):
|
|||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
|
||||
layer_skip_lambda = None,
|
||||
timesteps: Optional[list] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
|
@ -617,11 +650,6 @@ class LlamaModel_Adapted(LlamaModel):
|
|||
)
|
||||
use_cache = False
|
||||
|
||||
"""
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
"""
|
||||
|
||||
# kept for BC (non `Cache` `past_key_values` inputs)
|
||||
return_legacy_cache = False
|
||||
if use_cache and not isinstance(past_key_values, Cache):
|
||||
|
@ -646,17 +674,6 @@ class LlamaModel_Adapted(LlamaModel):
|
|||
|
||||
# because we can attend to both a causal and a non-causal sequence, generate both masks then pick among which to use per batch
|
||||
if is_causal is not None:
|
||||
"""
|
||||
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask,
|
||||
sequence_length=inputs_embeds.shape[1],
|
||||
target_length=attention_mask.shape[-1] if attention_mask is not None else inputs_embeds.shape[1],
|
||||
dtype=inputs_embeds.dtype,
|
||||
device=inputs_embeds.device,
|
||||
cache_position=cache_position,
|
||||
batch_size=inputs_embeds.shape[0],
|
||||
)
|
||||
"""
|
||||
causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions)
|
||||
noncausal_mask = self._update_noncausal_mask(attention_mask, inputs_embeds, past_key_values)
|
||||
|
||||
|
@ -690,7 +707,6 @@ class LlamaModel_Adapted(LlamaModel):
|
|||
use_cache,
|
||||
cache_position,
|
||||
position_embeddings,
|
||||
timesteps,
|
||||
)
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
|
@ -703,22 +719,16 @@ class LlamaModel_Adapted(LlamaModel):
|
|||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
timesteps=timesteps,
|
||||
)
|
||||
|
||||
if not self.dropoff_layer( l ):
|
||||
hidden_states = layer_outputs[0]
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if use_cache:
|
||||
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
|
||||
if use_cache:
|
||||
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
|
||||
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
|
||||
# check if we should early-exit
|
||||
if layer_skip_lambda and layer_skip_lambda( l, hidden_states ):
|
||||
#_logger.info(f"Early exit at layer: {l}")
|
||||
break
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
|
|
|
@ -1,51 +0,0 @@
|
|||
|
||||
from transformers.models.mamba.configuration_mamba import MambaConfig
|
||||
from transformers.models.mamba.modeling_mamba import MambaModel
|
||||
|
||||
"""
|
||||
from transformers.models.mamba2.modeling_mamba2 import Mamba2Model
|
||||
from transformers.models.mamba2.configuration_mamba2 import Mamba2Config
|
||||
"""
|
||||
|
||||
"""
|
||||
from mamba2_torch.modeling.configuration_mamba2 import Mamba2Config
|
||||
from mamba2_torch.modeling.modeling_mamba2 import Mamba2Model
|
||||
"""
|
||||
|
||||
from fla.models.mamba2.configuration_mamba2 import Mamba2Config
|
||||
from fla.models.mamba2.modeling_mamba2 import Mamba2Model
|
||||
|
||||
"""
|
||||
# https://github.com/state-spaces/mamba
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
|
||||
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel, MambaConfig, MixerModel as MambaMixelModel, layer_norm_fn as MambaLayerNormFn, RMSNorm as MambaRMSNorm
|
||||
|
||||
def MambaMixelModel_forward(self, input_ids=None, hidden_states=None, inference_params=None, **mixer_kwargs):
|
||||
if hidden_states is None:
|
||||
hidden_states = self.embedding(input_ids)
|
||||
residual = None
|
||||
for layer in self.layers:
|
||||
if self.gradient_checkpointing and hidden_states.requires_grad:
|
||||
hidden_states, residual = checkpoint( layer, hidden_states, residual, inference_params=inference_params, **mixer_kwargs, use_reentrant=False )
|
||||
else:
|
||||
hidden_states, residual = layer( hidden_states, residual, inference_params=inference_params, **mixer_kwargs )
|
||||
if not self.fused_add_norm:
|
||||
residual = (hidden_states + residual) if residual is not None else hidden_states
|
||||
hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
|
||||
else:
|
||||
# Set prenorm=False here since we don't need the residual
|
||||
hidden_states = MambaLayerNormFn(
|
||||
hidden_states,
|
||||
self.norm_f.weight,
|
||||
self.norm_f.bias,
|
||||
eps=self.norm_f.eps,
|
||||
residual=residual,
|
||||
prenorm=False,
|
||||
residual_in_fp32=self.residual_in_fp32,
|
||||
is_rms_norm=isinstance(self.norm_f, MambaRMSNorm)
|
||||
)
|
||||
return hidden_states
|
||||
|
||||
MambaMixelModel.forward = MambaMixelModel_forward
|
||||
"""
|
|
@ -1,796 +0,0 @@
|
|||
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/mixtral/modeling_mixtral.py
|
||||
|
||||
import math
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from typing import Literal, overload, Optional, Tuple, List, Union
|
||||
|
||||
from transformers import MixtralModel, MixtralConfig
|
||||
from transformers.models.mixtral.modeling_mixtral import load_balancing_loss_func, MixtralSparseMoeBlock, MixtralAttention, MixtralDecoderLayer, MixtralRMSNorm, repeat_kv
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPast, MoeModelOutputWithPast
|
||||
from transformers.cache_utils import Cache, DynamicCache, StaticCache
|
||||
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
|
||||
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
|
||||
from transformers.processing_utils import Unpack
|
||||
|
||||
from .attention import *
|
||||
|
||||
def rotate_half(x):
|
||||
"""Rotates half the hidden dims of the input."""
|
||||
x1 = x[..., : x.shape[-1] // 2]
|
||||
x2 = x[..., x.shape[-1] // 2 :]
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
|
||||
"""Applies Rotary Position Embedding to the query and key tensors.
|
||||
|
||||
Args:
|
||||
q (`torch.Tensor`): The query tensor.
|
||||
k (`torch.Tensor`): The key tensor.
|
||||
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
||||
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
||||
position_ids (`torch.Tensor`):
|
||||
The position indices of the tokens corresponding to the query and key tensors. For example, this can be
|
||||
used to pass offsetted position ids when working with a KV-cache.
|
||||
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
||||
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
||||
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
||||
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
|
||||
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
|
||||
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
|
||||
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
|
||||
Returns:
|
||||
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
||||
"""
|
||||
cos = cos[position_ids].unsqueeze(unsqueeze_dim)
|
||||
sin = sin[position_ids].unsqueeze(unsqueeze_dim)
|
||||
|
||||
q_embed = (q * cos) + (rotate_half(q) * sin)
|
||||
k_embed = (k * cos) + (rotate_half(k) * sin)
|
||||
return q_embed, k_embed
|
||||
|
||||
# This is required because batch sizes > 1 throws errors
|
||||
def MixtralSparseMoeBlock_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
""" """
|
||||
batch_size, sequence_length, hidden_dim = hidden_states.shape
|
||||
if self.training and self.jitter_noise > 0:
|
||||
hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise)
|
||||
#hidden_states = hidden_states.view(-1, hidden_dim)
|
||||
hidden_states = hidden_states.reshape(-1, hidden_dim)
|
||||
# router_logits: (batch * sequence_length, n_experts)
|
||||
router_logits = self.gate(hidden_states)
|
||||
|
||||
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
|
||||
routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
|
||||
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
|
||||
# we cast back to the input dtype
|
||||
routing_weights = routing_weights.to(hidden_states.dtype)
|
||||
|
||||
final_hidden_states = torch.zeros(
|
||||
(batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
|
||||
)
|
||||
|
||||
# One hot encode the selected experts to create an expert mask
|
||||
# this will be used to easily index which expert is going to be sollicitated
|
||||
expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
|
||||
|
||||
# Loop over all available experts in the model and perform the computation on each expert
|
||||
for expert_idx in range(self.num_experts):
|
||||
expert_layer = self.experts[expert_idx]
|
||||
idx, top_x = torch.where(expert_mask[expert_idx])
|
||||
|
||||
# Index the correct hidden states and compute the expert hidden state for
|
||||
# the current expert. We need to make sure to multiply the output hidden
|
||||
# states by `routing_weights` on the corresponding tokens (top-1 and top-2)
|
||||
current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
|
||||
current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
|
||||
|
||||
# However `index_add_` only support torch tensors for indexing so we'll use
|
||||
# the `top_x` tensor here.
|
||||
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
|
||||
final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
|
||||
return final_hidden_states, router_logits
|
||||
|
||||
MixtralSparseMoeBlock.forward = MixtralSparseMoeBlock_forward
|
||||
|
||||
class MixtralAttention_Adapted(MixtralAttention):
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.mode = kwargs.pop("mode", "sdpa")
|
||||
|
||||
if self.mode == "math":
|
||||
self.mode = torch.nn.attention.SDPBackend.MATH
|
||||
elif self.mode == "mem_efficient":
|
||||
self.mode = torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION
|
||||
elif self.mode == "flash_(sdpa)":
|
||||
self.mode = torch.nn.attention.SDPBackend.FLASH_ATTENTION
|
||||
elif self.mode == "cudnn":
|
||||
self.mode = torch.nn.attention.SDPBackend.CUDNN_ATTENTION
|
||||
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
# extracts inputs from a batch based on requested causality
|
||||
def split_forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
is_causal: Optional[list] = None,
|
||||
target_causal_state: Optional[bool] = True,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
|
||||
**kwargs,
|
||||
):
|
||||
indices = [ i for i, state in enumerate( is_causal ) if state == target_causal_state ]
|
||||
|
||||
# no matching inputs in batch
|
||||
if not indices:
|
||||
return indices, None, None, None
|
||||
|
||||
# entire batch is homogenous
|
||||
if len( indices ) == hidden_states.shape[0]:
|
||||
output_hidden_states, output_self_attn_weights, output_present_key_values = self.forward(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
is_causal=target_causal_state,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=False,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
return indices, output_hidden_states, output_self_attn_weights, output_present_key_values
|
||||
|
||||
input_hidden_states = torch.stack( [ hidden_states[i] for i in indices ] )
|
||||
input_attention_mask = torch.stack( [ attention_mask[i] for i in indices ] ) if attention_mask is not None else None
|
||||
input_position_ids = torch.stack( [ position_ids[i] for i in indices ] ) if position_ids is not None else None
|
||||
input_position_embeddings = (
|
||||
torch.stack( [ position_embeddings[0][i] for i in indices ] ),
|
||||
torch.stack( [ position_embeddings[1][i] for i in indices ] ),
|
||||
) if position_embeddings is not None else None
|
||||
|
||||
output_hidden_states, output_self_attn_weights, output_present_key_values = self.forward(
|
||||
hidden_states=input_hidden_states,
|
||||
attention_mask=input_attention_mask,
|
||||
is_causal=target_causal_state,
|
||||
position_ids=input_position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=False,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=input_position_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
return indices, output_hidden_states, output_self_attn_weights, output_present_key_values
|
||||
|
||||
# Adapted from LlamaAttention.forward
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
is_causal: bool = True,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
mode = "default" if output_attentions else self.mode
|
||||
non_split_attention = [
|
||||
"default",
|
||||
torch.nn.attention.SDPBackend.MATH,
|
||||
torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION,
|
||||
torch.nn.attention.SDPBackend.FLASH_ATTENTION,
|
||||
torch.nn.attention.SDPBackend.CUDNN_ATTENTION
|
||||
]
|
||||
|
||||
# split per batch because other attention mechanisms do not have a conditional is_causal per-batch, only for the entire input
|
||||
if isinstance( is_causal, list ) and mode not in non_split_attention:
|
||||
# initialize lists
|
||||
attn_hidden_states = [ None for _ in is_causal ]
|
||||
self_attn_weights = [ None for _ in is_causal ]
|
||||
present_key_values = [ None for _ in is_causal ]
|
||||
|
||||
# process causal inputs in a batch
|
||||
causal_indices, causal_hidden_states, causal_self_attn_weights, causal_present_key_values = self.split_forward(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
is_causal=is_causal,
|
||||
target_causal_state=True,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=False,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# process non-causal inputs in a batch
|
||||
non_causal_indices, non_causal_hidden_states, non_causal_self_attn_weights, non_causal_present_key_values = self.split_forward(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
is_causal=is_causal,
|
||||
target_causal_state=False,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=False,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# insert causal outputs to batch
|
||||
for i, idx in enumerate( causal_indices ):
|
||||
attn_hidden_states[idx] = causal_hidden_states[i]
|
||||
|
||||
if output_attentions:
|
||||
self_attn_weights[idx] = causal_self_attn_weights[i]
|
||||
|
||||
# insert non-causal outputs to batch
|
||||
for i, idx in enumerate( non_causal_indices ):
|
||||
attn_hidden_states[idx] = non_causal_hidden_states[i]
|
||||
|
||||
if output_attentions:
|
||||
self_attn_weights[idx] = non_causal_self_attn_weights[i]
|
||||
|
||||
# combine list
|
||||
attn_hidden_states = torch.stack( attn_hidden_states, dim=0 )
|
||||
if output_attentions:
|
||||
self_attn_weights = torch.stack( self_attn_weights, dim=0 )
|
||||
|
||||
return attn_hidden_states, output_attentions, []
|
||||
|
||||
dropout_rate = self.attention_dropout if self.training else 0.0
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
if self.layer_idx is None:
|
||||
raise ValueError(
|
||||
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
|
||||
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
|
||||
"with a layer index."
|
||||
)
|
||||
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
||||
|
||||
if past_key_value is not None:
|
||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
|
||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||
|
||||
# repeat k/v heads if n_kv_heads < n_heads
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
if past_key_value is not None:
|
||||
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||
|
||||
attn_scores = None
|
||||
|
||||
if mode in ["xformers", "flash_attn"]:
|
||||
# TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
|
||||
# to be able to avoid many of these transpose/reshape/view.
|
||||
query_states = query_states.transpose(1, 2)
|
||||
key_states = key_states.transpose(1, 2)
|
||||
value_states = value_states.transpose(1, 2)
|
||||
|
||||
"""
|
||||
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
|
||||
# therefore the input hidden states gets silently casted in float32. Hence, we need
|
||||
# cast them back in the correct dtype just to be sure everything works as expected.
|
||||
# This might slowdown training & inference so it is recommended to not cast the LayerNorms
|
||||
# in fp32. (LlamaRMSNorm handles it correctly)
|
||||
|
||||
input_dtype = query_states.dtype
|
||||
if input_dtype == torch.float32:
|
||||
if torch.is_autocast_enabled():
|
||||
target_dtype = torch.get_autocast_gpu_dtype()
|
||||
# Handle the case where the model is quantized
|
||||
elif hasattr(self.config, "_pre_quantization_dtype"):
|
||||
target_dtype = self.config._pre_quantization_dtype
|
||||
else:
|
||||
target_dtype = self.q_proj.weight.dtype
|
||||
|
||||
query_states = query_states.to(target_dtype)
|
||||
key_states = key_states.to(target_dtype)
|
||||
value_states = value_states.to(target_dtype)
|
||||
"""
|
||||
|
||||
if mode == "flash_attn":
|
||||
attn_output = flash_attn_func(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
causal=is_causal,
|
||||
softmax_scale=1.0 / math.sqrt(self.head_dim),
|
||||
dropout_p=dropout_rate,
|
||||
)
|
||||
|
||||
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
|
||||
elif mode == "xformers":
|
||||
attn_output = memory_efficient_attention(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attn_bias = LowerTriangularMask(),
|
||||
scale = 1.0 / math.sqrt(self.head_dim),
|
||||
p=dropout_rate
|
||||
)
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||
|
||||
attn_output = self.o_proj(attn_output)
|
||||
return attn_output, attn_scores, past_key_value
|
||||
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
x_mask = attention_mask
|
||||
|
||||
if attention_mask is not None:
|
||||
x_mask = x_mask[:, :, :, : key_states.shape[-2]]
|
||||
|
||||
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
|
||||
# Reference: https://github.com/pytorch/pytorch/issues/112577.
|
||||
if query_states.device.type == "cuda" and x_mask is not None:
|
||||
query_states = query_states.contiguous()
|
||||
key_states = key_states.contiguous()
|
||||
value_states = value_states.contiguous()
|
||||
|
||||
if mode in ["sageattn"]:
|
||||
attn_output = sageattn(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
tensor_layout="HND",
|
||||
is_causal=is_causal
|
||||
)
|
||||
elif mode in ["fused_attn"]:
|
||||
attn_output = fused_attn_func(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
causal=is_causal,
|
||||
softmax_scale=1.0 / math.sqrt(self.head_dim),
|
||||
dropout_p=dropout_rate,
|
||||
)
|
||||
elif mode in ["default"]:
|
||||
attn_scores = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||||
# cringe logic
|
||||
attn_weights = (attn_scores + x_mask) if attention_mask is not None else (attn_scores)
|
||||
# upcast attention to fp32
|
||||
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
||||
attn_weights = torch.nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
|
||||
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
||||
raise ValueError(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
||||
f" {attn_output.size()}"
|
||||
)
|
||||
else:
|
||||
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
|
||||
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
|
||||
# is_causal = True if x_mask is None and q_len > 1 else False
|
||||
is_causal = True if x_mask is None and q_len > 1 else False
|
||||
with torch.nn.attention.sdpa_kernel(self.mode):
|
||||
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attn_mask=x_mask,
|
||||
dropout_p=dropout_rate,
|
||||
is_causal=is_causal,
|
||||
)
|
||||
|
||||
# cringe
|
||||
if attn_scores is None and output_attentions:
|
||||
attn_scores = attn_output
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
attn_output = attn_output.view(bsz, q_len, -1)
|
||||
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
return attn_output, attn_scores, past_key_value
|
||||
|
||||
class MixtralDecoderLayer_Adapted(MixtralDecoderLayer):
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
is_causal: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
output_router_logits: Optional[bool] = False,
|
||||
use_cache: Optional[bool] = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
"""
|
||||
Args:
|
||||
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
||||
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
|
||||
`(batch, sequence_length)` where padding elements are indicated by 0.
|
||||
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
||||
returned tensors for more detail.
|
||||
output_router_logits (`bool`, *optional*):
|
||||
Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
|
||||
should not be returned during inference.
|
||||
use_cache (`bool`, *optional*):
|
||||
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
||||
(see `past_key_values`).
|
||||
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
||||
Indices depicting the position of the input sequence tokens in the sequence.
|
||||
kwargs (`dict`, *optional*):
|
||||
Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
|
||||
into the model
|
||||
"""
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
|
||||
# Self Attention
|
||||
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
position_embeddings=position_embeddings,
|
||||
attention_mask=attention_mask,
|
||||
is_causal=is_causal,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
# Fully Connected
|
||||
residual = hidden_states
|
||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||
hidden_states, router_logits = self.block_sparse_moe(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
outputs = (hidden_states,)
|
||||
|
||||
if output_attentions:
|
||||
outputs += (self_attn_weights,)
|
||||
|
||||
if output_router_logits:
|
||||
outputs += (router_logits,)
|
||||
|
||||
return outputs
|
||||
|
||||
class MixtralRotaryEmbedding(torch.nn.Module):
|
||||
def __init__(self, config: MixtralConfig, device=None):
|
||||
super().__init__()
|
||||
# BC: "rope_type" was originally "type"
|
||||
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
|
||||
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
||||
else:
|
||||
self.rope_type = "default"
|
||||
self.max_seq_len_cached = config.max_position_embeddings
|
||||
self.original_max_seq_len = config.max_position_embeddings
|
||||
|
||||
self.config = config
|
||||
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
||||
|
||||
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
self.original_inv_freq = self.inv_freq
|
||||
|
||||
def _dynamic_frequency_update(self, position_ids, device):
|
||||
"""
|
||||
dynamic RoPE layers should recompute `inv_freq` in the following situations:
|
||||
1 - growing beyond the cached sequence length (allow scaling)
|
||||
2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
|
||||
"""
|
||||
seq_len = torch.max(position_ids) + 1
|
||||
if seq_len > self.max_seq_len_cached: # growth
|
||||
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
|
||||
self.max_seq_len_cached = seq_len
|
||||
|
||||
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
|
||||
# This .to() is needed if the model has been moved to a device after being initialized (because
|
||||
# the buffer is automatically moved, but not the original copy)
|
||||
self.original_inv_freq = self.original_inv_freq.to(device)
|
||||
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
|
||||
self.max_seq_len_cached = self.original_max_seq_len
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, x, position_ids):
|
||||
if "dynamic" in self.rope_type:
|
||||
self._dynamic_frequency_update(position_ids, device=x.device)
|
||||
|
||||
# Core RoPE block
|
||||
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
|
||||
position_ids_expanded = position_ids[:, None, :].float()
|
||||
# Force float32 (see https://github.com/huggingface/transformers/pull/29285)
|
||||
device_type = x.device.type
|
||||
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
|
||||
with torch.autocast(device_type=device_type, enabled=False):
|
||||
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
cos = emb.cos()
|
||||
sin = emb.sin()
|
||||
|
||||
# Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
|
||||
cos = cos * self.attention_scaling
|
||||
sin = sin * self.attention_scaling
|
||||
|
||||
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
||||
|
||||
class MixtralModel_Adapted(MixtralModel):
|
||||
def __init__(self, config: MixtralConfig):
|
||||
#super().__init__(config)
|
||||
super(MixtralModel, self).__init__(config)
|
||||
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.vocab_size = config.vocab_size
|
||||
|
||||
self.embed_tokens = torch.nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
||||
self.layers = torch.nn.ModuleList(
|
||||
[MixtralDecoderLayer_Adapted(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
||||
)
|
||||
self.norm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.rotary_emb = MixtralRotaryEmbedding(config)
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def _update_noncausal_mask(
|
||||
self,
|
||||
attention_mask,
|
||||
inputs_embeds,
|
||||
past_key_values_length,
|
||||
):
|
||||
# create noncausal mask
|
||||
# [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
|
||||
|
||||
bsz, seq_len, _ = inputs_embeds.size()
|
||||
|
||||
# generate default mask based on input
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones( (bsz, seq_len), dtype=torch.bool, device=inputs_embeds.device )
|
||||
|
||||
# make square
|
||||
expanded_mask = attention_mask[:, None, None, :].expand( bsz, 1, seq_len, seq_len ).to( dtype=inputs_embeds.dtype )
|
||||
|
||||
# invert from 1.0 = attend, 0.0 = masked to 0.0 = valid, -inf = masked
|
||||
inverted_mask = 1.0 - expanded_mask
|
||||
return inverted_mask.masked_fill( inverted_mask.to(dtype=torch.bool), torch.finfo(inputs_embeds.dtype).min )
|
||||
|
||||
# gut out the things that just shoves responsibility on SDPA's is_causal generating a mask because this causes problems
|
||||
def _update_causal_mask(
|
||||
self,
|
||||
attention_mask: torch.Tensor,
|
||||
input_tensor: torch.Tensor,
|
||||
cache_position: torch.Tensor,
|
||||
past_key_values: Cache,
|
||||
output_attentions: bool,
|
||||
):
|
||||
"""
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
if attention_mask is not None and 0.0 in attention_mask:
|
||||
return attention_mask
|
||||
return None
|
||||
"""
|
||||
|
||||
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||
using_static_cache = isinstance(past_key_values, StaticCache)
|
||||
|
||||
"""
|
||||
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
||||
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
||||
# to infer the attention mask.
|
||||
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
|
||||
if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
|
||||
if AttentionMaskConverter._ignore_causal_mask_sdpa(
|
||||
attention_mask,
|
||||
inputs_embeds=input_tensor,
|
||||
past_key_values_length=past_seen_tokens,
|
||||
is_training=self.training,
|
||||
):
|
||||
return None
|
||||
"""
|
||||
|
||||
dtype, device = input_tensor.dtype, input_tensor.device
|
||||
sequence_length = input_tensor.shape[1]
|
||||
if using_static_cache:
|
||||
target_length = past_key_values.get_max_cache_shape()
|
||||
else:
|
||||
target_length = (
|
||||
attention_mask.shape[-1]
|
||||
if isinstance(attention_mask, torch.Tensor)
|
||||
else past_seen_tokens + sequence_length + 1
|
||||
)
|
||||
|
||||
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
|
||||
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask,
|
||||
sequence_length=sequence_length,
|
||||
target_length=target_length,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
cache_position=cache_position,
|
||||
batch_size=input_tensor.shape[0],
|
||||
config=self.config,
|
||||
past_key_values=past_key_values,
|
||||
)
|
||||
|
||||
if (
|
||||
self.config._attn_implementation == "sdpa"
|
||||
and attention_mask is not None
|
||||
and attention_mask.device.type == "cuda"
|
||||
and not output_attentions
|
||||
):
|
||||
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
|
||||
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
|
||||
# Details: https://github.com/pytorch/pytorch/issues/110213
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
|
||||
|
||||
return causal_mask
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
is_causal: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
output_router_logits: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_router_logits = (
|
||||
output_router_logits if output_router_logits is not None else self.config.output_router_logits
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
use_cache = False
|
||||
|
||||
if use_cache and past_key_values is None:
|
||||
past_key_values = DynamicCache()
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
if cache_position is None:
|
||||
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||
cache_position = torch.arange(
|
||||
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
|
||||
)
|
||||
if position_ids is None:
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
#causal_mask = self._update_causal_mask(
|
||||
# attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
|
||||
#)
|
||||
# because we can attend to both a causal and a non-causal sequence, generate both masks then pick among which to use per batch
|
||||
if is_causal is not None:
|
||||
"""
|
||||
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask,
|
||||
sequence_length=inputs_embeds.shape[1],
|
||||
target_length=attention_mask.shape[-1] if attention_mask is not None else inputs_embeds.shape[1],
|
||||
dtype=inputs_embeds.dtype,
|
||||
device=inputs_embeds.device,
|
||||
cache_position=cache_position,
|
||||
batch_size=inputs_embeds.shape[0],
|
||||
)
|
||||
"""
|
||||
causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions)
|
||||
noncausal_mask = self._update_noncausal_mask(attention_mask, inputs_embeds, past_key_values)
|
||||
|
||||
x_mask = torch.stack( [ causal_mask[i, :, :, :] if state else noncausal_mask[i, :, :, :] for i, state in enumerate( is_causal ) ], dim=0 )
|
||||
else:
|
||||
x_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions)
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
# create position embeddings to be shared across the decoder layers
|
||||
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
all_router_logits = () if output_router_logits else None
|
||||
|
||||
for decoder_layer in self.layers:
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
decoder_layer.__call__,
|
||||
hidden_states,
|
||||
x_mask,
|
||||
is_causal,
|
||||
position_ids,
|
||||
past_key_values,
|
||||
output_attentions,
|
||||
output_router_logits,
|
||||
use_cache,
|
||||
cache_position,
|
||||
position_embeddings,
|
||||
)
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=x_mask,
|
||||
is_causal=is_causal,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
output_router_logits=output_router_logits,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
**flash_attn_kwargs,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
|
||||
if output_router_logits:
|
||||
all_router_logits += (layer_outputs[-1],)
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
# add hidden states from the last decoder layer
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
output = MoeModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=past_key_values,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attns,
|
||||
router_logits=all_router_logits,
|
||||
)
|
||||
return output if return_dict else output.to_tuple()
|
|
@ -1,46 +0,0 @@
|
|||
# https://github.com/microsoft/torchscale
|
||||
|
||||
from torchscale.architecture.config import RetNetConfig
|
||||
from torchscale.architecture.retnet import RetNetDecoder
|
||||
# from retnet import RetNet
|
||||
|
||||
# override MultiScaleRetention's forward because training with te throws an error
|
||||
from torchscale.component.multiscale_retention import MultiScaleRetention, theta_shift
|
||||
|
||||
def MultiScaleRetention_forward(
|
||||
self,
|
||||
x,
|
||||
rel_pos,
|
||||
chunkwise_recurrent=False,
|
||||
incremental_state=None
|
||||
):
|
||||
bsz, tgt_len, _ = x.size()
|
||||
(sin, cos), inner_mask = rel_pos
|
||||
|
||||
q = self.q_proj(x)
|
||||
k = self.k_proj(x) * self.scaling
|
||||
v = self.v_proj(x)
|
||||
g = self.g_proj(x)
|
||||
|
||||
q = q.view(bsz, tgt_len, self.num_heads, self.key_dim).transpose(1, 2)
|
||||
k = k.view(bsz, tgt_len, self.num_heads, self.key_dim).transpose(1, 2)
|
||||
|
||||
qr = theta_shift(q, sin, cos)
|
||||
kr = theta_shift(k, sin, cos)
|
||||
|
||||
if incremental_state is not None:
|
||||
output = self.recurrent_forward(qr, kr, v, inner_mask, incremental_state)
|
||||
elif chunkwise_recurrent:
|
||||
output = self.chunk_recurrent_forward(qr, kr, v, inner_mask)
|
||||
else:
|
||||
output = self.parallel_forward(qr, kr, v, inner_mask)
|
||||
|
||||
output = self.group_norm(output).reshape(bsz, tgt_len, self.head_dim * self.num_heads)
|
||||
|
||||
output = self.gate_fn(g) * output
|
||||
|
||||
output = self.out_proj(output)
|
||||
|
||||
return output
|
||||
|
||||
MultiScaleRetention.forward = MultiScaleRetention_forward
|
|
@ -1,217 +0,0 @@
|
|||
"""
|
||||
# https://github.com/enhuiz/vall-e/
|
||||
"""
|
||||
|
||||
import math
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import traceback
|
||||
|
||||
from typing import Literal, overload
|
||||
from functools import partial
|
||||
from einops import rearrange
|
||||
|
||||
from torch import Tensor, einsum, nn
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
|
||||
from ...utils import ml
|
||||
|
||||
class AdaLN(nn.Module):
|
||||
def __init__(self, d_model, n_levels, eps=1e-5, k=0.1, c=2):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.emb = nn.Embedding(n_levels, d_model * 2)
|
||||
self.k = k
|
||||
self.c = c
|
||||
nn.init.zeros_(self.emb.weight)
|
||||
|
||||
def forward(self, x, l):
|
||||
h = F.layer_norm(x, x.shape[-1:], eps=self.eps)
|
||||
|
||||
# The initial implementation (https://github.com/enhuiz/vall-e/blob/fbf023448c08e55c0422eefed7fc234cf8b76680/vall_e/vall_e/base.py#L135)
|
||||
# performed worse than vanilla LayerNorm.
|
||||
# The authors mentioned another AdaNorm paper (https://openreview.net/pdf?id=HyxndNrxLB) as they introduce AdaLN.
|
||||
# Did they use AdaNorm inside AdaLN? (as follows)
|
||||
h = self.c * (1 - (self.k * h).detach()) * h
|
||||
|
||||
logγ, β = self.emb(l).unsqueeze(1).chunk(2, dim=-1)
|
||||
y = logγ.exp() * h + β
|
||||
|
||||
return y
|
||||
|
||||
class SinusoidalEmbedding(nn.Module):
|
||||
def __init__(self, d_model):
|
||||
super().__init__()
|
||||
self.d_model = d_model
|
||||
exponent = torch.arange(self.d_half, dtype=torch.float32)
|
||||
exponent = exponent / self.d_half
|
||||
omega = torch.exp(-math.log(1e4) * exponent)
|
||||
self.omega: torch.Tensor
|
||||
self.register_buffer("omega", omega, persistent=False)
|
||||
|
||||
@property
|
||||
def d_half(self):
|
||||
assert self.d_model % 2 == 0, "Only support even d_model."
|
||||
return self.d_model // 2
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Args:
|
||||
x: (...)
|
||||
Returns:
|
||||
pe: (... d)
|
||||
"""
|
||||
omega = self.omega
|
||||
|
||||
while omega.dim() <= x.dim():
|
||||
omega = omega.unsqueeze(0) # (... d)
|
||||
|
||||
x = x.unsqueeze(-1) # (... 1)
|
||||
x = omega * x
|
||||
x = torch.cat([x.sin(), x.cos()], dim=-1)
|
||||
|
||||
return x
|
||||
|
||||
def get_pe(self, n: int):
|
||||
"""
|
||||
Args:
|
||||
n: int
|
||||
Returns:
|
||||
pe: (n d)
|
||||
"""
|
||||
device = self.omega.device
|
||||
return self.forward(torch.arange(n, device=device))
|
||||
|
||||
def add_pe(self, x):
|
||||
"""
|
||||
Args:
|
||||
x: (b t c)
|
||||
"""
|
||||
e = self.get_pe(x.shape[1]) # t d
|
||||
e = e[None] # b t d
|
||||
x = x + e
|
||||
return x
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, d_model, n_heads, causal):
|
||||
super().__init__()
|
||||
assert d_model % n_heads == 0
|
||||
dim_head = d_model // n_heads
|
||||
self.causal = causal
|
||||
self.n_heads = n_heads
|
||||
self.scale = dim_head**-0.5
|
||||
|
||||
self.to_qkv = ml.Linear(d_model, d_model * 3, bias=False)
|
||||
self.to_out = ml.Linear(d_model, d_model)
|
||||
|
||||
def forward(self, x, m):
|
||||
"""
|
||||
Args:
|
||||
x: (b t c)
|
||||
m: (b t c), 1 is data, 0 is padding
|
||||
Returns:
|
||||
x: (b t c)
|
||||
"""
|
||||
h = self.n_heads
|
||||
|
||||
q, k, v = self.to_qkv(x).chunk(3, dim=-1)
|
||||
q, k, v = map(lambda t: rearrange(t, "b t (h d) -> b t h d", h=h), (q, k, v))
|
||||
|
||||
e = einsum("b i h d, b j h d -> b i j h", q, k)
|
||||
e = e * self.scale
|
||||
|
||||
kpm = m.unsqueeze(1) * m.unsqueeze(2) # b i j 1
|
||||
|
||||
if self.causal:
|
||||
with ml.autocast(kpm, torch.bfloat16, torch.float16) as k:
|
||||
kpm = k.squeeze(-1).tril().unsqueeze(-1) # b i j 1
|
||||
|
||||
e = e.masked_fill(kpm == 0, -torch.finfo(e.dtype).max)
|
||||
a = e.softmax(dim=2) # Normalize on j, i.e. key
|
||||
|
||||
o = einsum("b i j h, b j h d -> b i h d", a, v)
|
||||
o = o.flatten(-2)
|
||||
o = self.to_out(o) # b t c
|
||||
|
||||
o = o * m
|
||||
|
||||
return o
|
||||
|
||||
class PrenormResidual(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
block,
|
||||
d_model,
|
||||
p_dropout,
|
||||
requires_mask=False,
|
||||
norm_type="ln",
|
||||
n_levels: int | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.block = block
|
||||
self.requires_mask = requires_mask
|
||||
self.norm_type = norm_type
|
||||
if norm_type == "ln":
|
||||
self.norm = nn.LayerNorm(d_model)
|
||||
elif norm_type == "adaln":
|
||||
assert n_levels is not None
|
||||
self.norm = AdaLN(d_model, n_levels)
|
||||
else:
|
||||
raise NotImplementedError(norm_type)
|
||||
self.dropout = nn.Dropout(p_dropout)
|
||||
|
||||
def forward(self, x, m, l):
|
||||
"""
|
||||
Args:
|
||||
x: input (b t d)
|
||||
m: mask (b t 1), 1 is valuable and 0 is padding
|
||||
l: level to use, required only for AdaLN
|
||||
"""
|
||||
nopts = {"l": l} if self.norm_type == "adaln" else {}
|
||||
bopts = {"m": m} if self.requires_mask else {}
|
||||
x = x + self.dropout(self.block(self.norm(x, **nopts) * m, **bopts))
|
||||
return x * m
|
||||
|
||||
|
||||
class Block(nn.Sequential):
|
||||
def __init__(self, d_model, n_heads, p_dropout, causal, norm_type, n_levels, activation_checkpointing=True):
|
||||
super().__init__()
|
||||
|
||||
self.activation_checkpointing = activation_checkpointing
|
||||
self.attn = PrenormResidual(
|
||||
Attention(d_model, n_heads, causal),
|
||||
d_model=d_model,
|
||||
p_dropout=p_dropout,
|
||||
requires_mask=True,
|
||||
norm_type=norm_type,
|
||||
n_levels=n_levels,
|
||||
)
|
||||
|
||||
n_ff = d_model * 4 # 1024 * 4 = 4096 feed-forwards
|
||||
self.ffn = PrenormResidual(
|
||||
nn.Sequential(
|
||||
ml.Linear(d_model, n_ff),
|
||||
nn.GELU(),
|
||||
nn.Dropout(p_dropout),
|
||||
ml.Linear(n_ff, d_model),
|
||||
),
|
||||
d_model=d_model,
|
||||
p_dropout=p_dropout,
|
||||
norm_type=norm_type,
|
||||
n_levels=n_levels,
|
||||
)
|
||||
|
||||
def forward(self, x, m, l):
|
||||
"""
|
||||
Args:
|
||||
x: (b t c)
|
||||
m: (b t 1)
|
||||
l: (b)
|
||||
"""
|
||||
if x.requires_grad and self.activation_checkpointing:
|
||||
x = checkpoint(self.attn, x, m, l, use_reentrant=False)
|
||||
else:
|
||||
x = self.attn(x, m, l)
|
||||
x = self.ffn(x, m, l)
|
||||
return x
|
|
@ -47,10 +47,6 @@ Logits = namedtuple('Logits', ['logits', 'state', 'inputs', 'loss', 'attentions'
|
|||
Sampled = namedtuple('Sampled', ['ids', 'logits', 'scores', 'entropy'])
|
||||
LossStats = namedtuple('LossStats', ['loss', 'stats'])
|
||||
|
||||
"""
|
||||
from ..utils.pattern import DelayedPatternProvider, VALLEPattern
|
||||
"""
|
||||
|
||||
summed_embeddings_task = [ "stt" ]
|
||||
special_tasks = [ "len", "stt", "phn", "text", "un-phn" ]
|
||||
non_tokened_names = ["task", "dropout_mask", "classifier_level"]
|
||||
|
@ -128,73 +124,6 @@ def _interleave_sequence_reshape( input: list[torch.Tensor], dim=-1 ):
|
|||
def _interleave_sequence_flatten( input: list[torch.Tensor] ):
|
||||
return torch.concat( [ i.t() for i in input ] ).t().flatten()
|
||||
|
||||
# Deprecated implementation
|
||||
class MultiEmbedding(nn.Module):
|
||||
def __init__(self, max_n_levels, n_tokens, token_dim, monolithic=False):
|
||||
super().__init__()
|
||||
self.monolithic = monolithic
|
||||
self.max_n_levels = max_n_levels
|
||||
self.n_tokens = n_tokens
|
||||
self.weight = nn.Parameter(torch.randn(max_n_levels, n_tokens, token_dim))
|
||||
|
||||
# to-do: select quant level from given quant_levels tensor if given (i.e. through the resps_emb)
|
||||
# I imagine this is an oversight in the NAR.
|
||||
def forward(self, x_list: list[Tensor], quant_level: int | list[int] | Tensor | None = None) -> list[Tensor]:
|
||||
if len(x_list) == 0:
|
||||
return []
|
||||
|
||||
# this "strategy" will reserve the weight[0] for te AR and weight[1:] for the NAR
|
||||
# the NAR cannot share RVQ-bin level 0 with the AR for the resps_emb
|
||||
if self.monolithic:
|
||||
w = self.weight[:1] if quant_level is None or quant_level == 0 else self.weight[1:]
|
||||
else:
|
||||
w = self.weight
|
||||
|
||||
padded_x_list = []
|
||||
|
||||
for i, xi in enumerate(x_list):
|
||||
xi = F.one_hot(xi.to(torch.int64), num_classes=self.n_tokens) # t l' k
|
||||
wi = w.shape[0] - xi.shape[1]
|
||||
xi = F.pad(xi, (0, 0, 0, wi)) # t l k
|
||||
padded_x_list.append(xi.to(w))
|
||||
|
||||
x = torch.cat(padded_x_list) # n l k
|
||||
x = einsum("l k d, n l k -> n d", w, x)
|
||||
|
||||
x_list = x.split([*map(len, x_list)])
|
||||
|
||||
return x_list
|
||||
|
||||
# Embedding that sums each RVQ-bin level within a given input acoustic prompt
|
||||
# _Old, to preserve compat with previous models.
|
||||
class AudioEmbedding_Old(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
l_embedding_tokens: int, # list of number of tokens (needed because AR resps includes stop token)
|
||||
token_dim: int, # dimensionality of the embedding
|
||||
levels: int | None = None, # number of RVQ-bins (I don't remember the specifics)
|
||||
):
|
||||
super().__init__()
|
||||
# array of embeddings
|
||||
# proms are [0, resp_levels]
|
||||
# resp are split to where [0] is for the AR, and [1:] are reserved for NAR
|
||||
self.embeddings = nn.ModuleList([ml.Embedding(n_tokens, token_dim) for n_tokens in l_embedding_tokens])
|
||||
# weight influencer for the influence for each level (desu this should be really useless because the weights in the embedding themselves should factor this)
|
||||
self.weight = nn.ParameterList([nn.Parameter( torch.tensor([1]) ) for i in range(levels)]) if levels is not None else None
|
||||
|
||||
def forward(self, xi: Tensor, quant_level: Tensor | None = None ) -> Tensor:
|
||||
# prom
|
||||
if quant_level is None and xi.shape[-1] > 1:
|
||||
x = sum( [ self.embeddings[k]( xi[:, k] ) * (self.weight[k] if self.weight is not None else 1) for k in range(xi.shape[-1]) ] )
|
||||
# prom / AR resp
|
||||
elif quant_level is None or quant_level == 0:
|
||||
x = self.embeddings[0]( xi if xi.dim() == 1 else xi[:, 0] )
|
||||
# NAR resp
|
||||
else:
|
||||
x = sum( [ self.embeddings[k+1]( xi[:, k] ) * (self.weight[k+1] if self.weight is not None else 1) for k in range(xi.shape[-1]) ] )
|
||||
|
||||
return x
|
||||
|
||||
# Embedding that sums each RVQ-bin level within a given input acoustic prompt
|
||||
# Mostly to handle some oversights and errors during testing
|
||||
class AudioEmbedding(nn.Module):
|
||||
|
@ -235,27 +164,6 @@ class AudioEmbedding(nn.Module):
|
|||
|
||||
return x
|
||||
|
||||
# time-step embedding
|
||||
# for the NAR-len, since it probably most likely requires encoding the timestep
|
||||
class TimeEmbedding(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
d_model
|
||||
):
|
||||
super().__init__()
|
||||
self.emb = SinusoidalEmbedding(d_model)
|
||||
self.mlp = nn.Sequential(
|
||||
nn.Linear(d_model, d_model*4),
|
||||
nn.SiLU(),
|
||||
nn.Linear(d_model*4, d_model),
|
||||
)
|
||||
|
||||
def forward( self, t ):
|
||||
t = self.emb(t)
|
||||
t = self.mlp(t)
|
||||
|
||||
return t
|
||||
|
||||
# per-level classification
|
||||
# it might actually be "better" in the long run to only have one output head like a traditional LM, and just de-stitch it here instead of doing modulus math and whatever like the HF/experimental impl
|
||||
class Classifiers(nn.Module):
|
||||
|
@ -380,38 +288,6 @@ class Base(nn.Module):
|
|||
|
||||
return l[: indices.min().item()]
|
||||
|
||||
# these probably need to live in an interleaved model, as pattern-ing is targeted for a sole AR model
|
||||
"""
|
||||
def codes_to_pattern(self, codes):
|
||||
# expand if not batched
|
||||
if codes.dim() == 2:
|
||||
codes = codes.unsqueeze(0)
|
||||
# [batch, timestep, rvq level] (B, T, K) => [batch, rvq level, timestep] (B, K, T)
|
||||
codes = codes.permute(0, 2, 1)
|
||||
|
||||
B, K, T = codes.shape
|
||||
|
||||
# map codes [B, K, T] into pattern sequence [B, K, S] using special_token_id for masked tokens
|
||||
pattern = self.pattern_provider.get_pattern(T)
|
||||
sequence_codes, sequence_indexes, sequence_mask = pattern.build_pattern_sequence(
|
||||
codes.contiguous(), self.stop_token, keep_only_valid_steps=False,
|
||||
)
|
||||
|
||||
# (B, K, T) => (B, T, K)
|
||||
return sequence_codes.permute(0, 2, 1)
|
||||
|
||||
def logits_from_pattern(self, logits, pattern):
|
||||
logits = logits.permute(0, 3, 1, 2) # [B, card, K, S]
|
||||
|
||||
logits, logits_indexes, logits_mask = pattern.revert_pattern_logits(
|
||||
logits, float('nan'), keep_only_valid_steps=False
|
||||
)
|
||||
logits = logits.permute(0, 2, 3, 1) # [B, K, T, card]
|
||||
logits_mask = logits_mask[None, :, :].expand(B, -1, -1) # [K, T] -> [B, K, T]
|
||||
|
||||
return logits, logits_mask
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
|
@ -453,13 +329,13 @@ class Base(nn.Module):
|
|||
|
||||
self.n_resp_levels = self.config.resp_levels if self.config else n_resp_levels
|
||||
self.n_max_levels = self.config.max_levels if self.config else n_resp_levels
|
||||
self.capabilities = self.config.capabilities if self.config else ["ar", "nar"]
|
||||
self.capabilities = self.config.capabilities if self.config else ["ar", "nar", "len"]
|
||||
self.gradient_checkpointing = self.config.gradient_checkpointing if self.config is not None else True
|
||||
|
||||
self.stop_token = self.n_audio_tokens
|
||||
self.mask_token = self.stop_token
|
||||
self.causal = "ar" in self.capabilities or "len" in self.capabilities
|
||||
self.version = self.config.version if self.config is not None else 5
|
||||
self.causal = True
|
||||
self.version = self.config.version if self.config is not None else 6
|
||||
self.causal_size = self.config.experimental.causal_size if self.config is not None else (1 if self.causal else 0)
|
||||
|
||||
self.arch_type = self.config.arch_type if self.config is not None else "llama"
|
||||
|
@ -500,32 +376,11 @@ class Base(nn.Module):
|
|||
n_tasks = self.config.tasks if self.config is not None else 8
|
||||
n_langs = self.config.langs if self.config is not None else 2
|
||||
n_tones = self.config.tones if self.config is not None else 1
|
||||
|
||||
# pure AR
|
||||
if "nar" not in self.capabilities:
|
||||
n_resp_tokens = n_audio_tokens + 1
|
||||
l_embedding_tokens = [n_resp_tokens] * self.n_resp_levels
|
||||
l_embedding_names = [f'AR:{i}:{i}' for i in range( self.n_resp_levels )]
|
||||
l_classifier_tokens = [n_resp_tokens] * self.n_resp_levels
|
||||
# NAR-len model
|
||||
elif "len" in self.capabilities:
|
||||
# +1 to include the stop or mask token
|
||||
n_resp_tokens = n_audio_tokens + ( 1 if self.causal_size > 0 else 0 )
|
||||
if "ar" in self.capabilities:
|
||||
l_embedding_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1) + [n_resp_tokens]
|
||||
l_classifier_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1) + [n_resp_tokens - 1]
|
||||
l_embedding_names = ['AR:0:0'] + [f'NAR:{i}:{i+1}' for i in range( self.n_resp_levels - 1 )] + ['NAR:0:0']
|
||||
else:
|
||||
l_embedding_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
|
||||
l_classifier_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
|
||||
l_embedding_names = ['NAR:0:0'] + [f'NAR:{i}:{i+1}' for i in range( self.n_resp_levels - 1 )]
|
||||
# AR+NAR model
|
||||
else:
|
||||
# +1 to include the stop or mask token
|
||||
n_resp_tokens = n_audio_tokens + ( 1 if self.causal_size > 0 else 0 )
|
||||
l_embedding_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
|
||||
l_embedding_names = ['AR:0:0'] + [f'NAR:{i}:{i+1}' for i in range( self.n_resp_levels - 1 )]
|
||||
l_classifier_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
|
||||
|
||||
n_resp_tokens = n_audio_tokens + ( 1 if self.causal_size > 0 else 0 )
|
||||
l_embedding_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1) + [n_resp_tokens]
|
||||
l_classifier_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1) + [n_resp_tokens - 1]
|
||||
l_embedding_names = ['AR:0:0'] + [f'NAR:{i}:{i+1}' for i in range( self.n_resp_levels - 1 )] + ['NAR:0:0']
|
||||
|
||||
n_vocab = 17702 if not split_classifiers else n_resp_tokens + 1
|
||||
|
||||
|
@ -564,65 +419,28 @@ class Base(nn.Module):
|
|||
self.sep = nn.Parameter(torch.randn(d_model))
|
||||
self.dropout_token = nn.Parameter(torch.randn(d_model))
|
||||
|
||||
if self.version == 1: # legacy
|
||||
n_audio_tokens += (n_tasks - 1) # old models have the task tokens in the prom
|
||||
self.proms_emb = MultiEmbedding(self.n_resp_levels, n_audio_tokens, d_model)
|
||||
self.resps_emb = MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model, monolithic=self.monolithic)
|
||||
self.audio_emb = None
|
||||
elif self.version < 5:
|
||||
# [1024] * 8
|
||||
self.proms_emb = AudioEmbedding_Old(
|
||||
[n_audio_tokens] * self.n_resp_levels, d_model,
|
||||
levels=self.n_resp_levels if self.version > 3 else None,
|
||||
)
|
||||
# [1024 + STOP] + [1024] * 8
|
||||
self.resps_emb = AudioEmbedding_Old(
|
||||
l_embedding_tokens, d_model,
|
||||
levels=self.n_resp_levels if self.version > 3 else None,
|
||||
)
|
||||
self.audio_emb = None
|
||||
else:
|
||||
self.proms_emb = AudioEmbedding(
|
||||
[n_audio_tokens] * self.n_resp_levels, d_model,
|
||||
sums=audio_embedding_sums == "prom" or audio_embedding_sums == True,
|
||||
)
|
||||
self.resps_emb = AudioEmbedding(
|
||||
l_embedding_tokens, d_model,
|
||||
sums=audio_embedding_sums == "resp" or audio_embedding_sums == True,
|
||||
l_embedding_names=l_embedding_names,
|
||||
)
|
||||
self.audio_emb = None
|
||||
self.proms_emb = AudioEmbedding(
|
||||
[n_audio_tokens] * self.n_resp_levels, d_model,
|
||||
sums=audio_embedding_sums == "prom" or audio_embedding_sums == True,
|
||||
)
|
||||
self.resps_emb = AudioEmbedding(
|
||||
l_embedding_tokens, d_model,
|
||||
sums=audio_embedding_sums == "resp" or audio_embedding_sums == True,
|
||||
l_embedding_names=l_embedding_names,
|
||||
)
|
||||
|
||||
if self.version >= 3:
|
||||
self.langs_emb = Embedding(n_langs, d_model) if n_langs > 0 else None
|
||||
self.tasks_emb = Embedding(n_tasks, d_model) if n_tasks > 0 else None
|
||||
|
||||
self.capabilities += ["lang"]
|
||||
self.langs_emb = Embedding(n_langs, d_model) if n_langs > 0 else None
|
||||
self.tasks_emb = Embedding(n_tasks, d_model) if n_tasks > 0 else None
|
||||
self.capabilities += ["lang"]
|
||||
# never actually got added... I kept forgetting to classify all my audio for speaker's tone
|
||||
if self.version >= 4:
|
||||
self.tones_emb = Embedding(n_tones, d_model) if n_tones > 0 else None
|
||||
self.tones_emb = Embedding(n_tones, d_model) if n_tones > 0 else None
|
||||
|
||||
# mamba requires this if a model does both AR and NAR tasks
|
||||
# this *might* help for AR and NAR tasks since we explicitly specify the current RVQ level for a sequence, rather than having it "encoded" in the embeddings
|
||||
# this ***might*** let me also unify the proms_emb and resps_embedding
|
||||
if self.version >= 5:
|
||||
# "len" RVQ level-0 gets an additional token
|
||||
if self.version < 7:
|
||||
self.rvq_l_emb = Embedding(self.n_resp_levels, d_model)
|
||||
|
||||
# experimental NAR-only mode
|
||||
self.len_emb = Embedding(11, d_model)
|
||||
if self.version >= 6:
|
||||
self.raw_text_emb = Embedding(self.n_text_tokens, d_model)
|
||||
self.rvq_l_emb = Embedding(self.n_resp_levels, d_model)
|
||||
self.len_emb = Embedding(11, d_model)
|
||||
self.raw_text_emb = Embedding(self.n_text_tokens, d_model)
|
||||
|
||||
if attention_backend == "auto":
|
||||
attention_backend = "sdpa"
|
||||
"""
|
||||
if AVAILABLE_ATTENTIONS:
|
||||
attention_backend = AVAILABLE_ATTENTIONS[0]
|
||||
else:
|
||||
attention_backend = "default"
|
||||
"""
|
||||
|
||||
hf_attention = attention_backend
|
||||
HF_ATTENTIONS = ["eager", "sdpa", "flash_attention_2"]
|
||||
|
@ -638,132 +456,21 @@ class Base(nn.Module):
|
|||
elif attention_backend == "fused_attn":
|
||||
self.l_padding = 128
|
||||
|
||||
if self.arch_type == "transformer":
|
||||
self.sin_emb = SinusoidalEmbedding(d_model)
|
||||
self.blocks = nn.ModuleList([TransformerBlock(
|
||||
d_model=d_model,
|
||||
n_heads=n_heads,
|
||||
p_dropout=p_dropout if training else 0.0,
|
||||
causal=self.causal,
|
||||
norm_type="ln", # adaln
|
||||
n_levels=self.n_resp_levels,
|
||||
) for _ in range(n_layers) ])
|
||||
elif self.arch_type in ["llama", "mistral", "mixtral"]:
|
||||
if n_experts <= 1:
|
||||
self.model = LlamaModel_Adapted(LlamaConfig(
|
||||
vocab_size=n_vocab,
|
||||
hidden_size=d_model,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
intermediate_size=d_model*d_ffn,
|
||||
num_hidden_layers=n_layers,
|
||||
num_attention_heads=n_heads,
|
||||
attention_dropout=p_dropout if training else 0.0,
|
||||
num_key_value_heads=n_heads,
|
||||
#sliding_window=75 * 12, # 12 second context window
|
||||
hidden_act="gelu",
|
||||
is_encoder_decoder=False,
|
||||
is_decoder=True,
|
||||
attn_implementation=hf_attention,
|
||||
#gradient_checkpointing=self.gradient_checkpointing,
|
||||
))
|
||||
|
||||
self.model = ml.replace_attention( self.model, klass=LlamaAttention_Adapted, target=LlamaAttention, mode=attention_backend )
|
||||
"""
|
||||
# replace with desired attention
|
||||
if attention_backend not in HF_ATTENTIONS:
|
||||
self.model = ml.replace_attention( self.model, klass=LlamaAttention_Adapted, target=LlamaAttention, mode=attention_backend )
|
||||
"""
|
||||
else:
|
||||
self.model = MixtralModel_Adapted(MixtralConfig(
|
||||
vocab_size =n_resp_tokens,
|
||||
hidden_size=d_model,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
intermediate_size=d_model*4,
|
||||
num_hidden_layers=n_layers,
|
||||
num_attention_heads=n_heads,
|
||||
attention_dropout=p_dropout if training else 0.0,
|
||||
num_key_value_heads=n_heads,
|
||||
#sliding_window=75 * 12, # 12 second context window
|
||||
output_router_logits=training,
|
||||
hidden_act="gelu",
|
||||
is_encoder_decoder=False,
|
||||
is_decoder=True,
|
||||
num_local_experts=n_experts,
|
||||
num_experts_per_tok=min(2, n_experts),
|
||||
attn_implementation=hf_attention,
|
||||
#gradient_checkpointing=self.gradient_checkpointing,
|
||||
))
|
||||
self.model = ml.replace_attention( self.model, klass=MixtralAttention_Adapted, target=MixtralAttention, mode=attention_backend )
|
||||
"""
|
||||
if attention_backend not in HF_ATTENTIONS:
|
||||
self.model = ml.replace_attention( self.model, klass=MixtralAttention_Adapted, target=MixtralAttention, mode=attention_backend )
|
||||
"""
|
||||
|
||||
if self.gradient_checkpointing and not self.model.gradient_checkpointing:
|
||||
self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
|
||||
use_reentrant=False
|
||||
))
|
||||
elif self.arch_type == "retnet":
|
||||
kwargs = dict(
|
||||
vocab_size=n_vocab,
|
||||
decoder_embed_dim=d_model,
|
||||
decoder_value_embed_dim =d_model * 2,
|
||||
decoder_retention_heads=n_heads,
|
||||
decoder_ffn_embed_dim=d_model * 4,
|
||||
decoder_layers=n_layers,
|
||||
dropout=p_dropout if training else 0.0,
|
||||
checkpoint_activations=self.gradient_checkpointing,
|
||||
activation_fn="gelu",
|
||||
use_layernorm=self.version < 3,
|
||||
use_biases=self.version < 3,
|
||||
use_glu=self.version >= 3,
|
||||
|
||||
chunkwise_recurrent=self.causal and self.causal_size > 0,
|
||||
recurrent_chunkwise_size=self.causal_size if self.causal else 0,
|
||||
no_output_layer=True,
|
||||
decoder_normalize_before=True,
|
||||
|
||||
rotary_embedding_base=10000
|
||||
)
|
||||
|
||||
if n_experts > 1:
|
||||
kwargs.update(dict(
|
||||
use_xmoe=True,
|
||||
moe_freq=1,
|
||||
moe_expert_count=n_experts,
|
||||
moe_gating_use_fp32=False,
|
||||
))
|
||||
|
||||
self.model = RetNetDecoder(RetNetConfig(**kwargs))
|
||||
elif self.arch_type in ["mamba2"]:
|
||||
self.model = Mamba2Model(Mamba2Config(
|
||||
vocab_size=n_vocab,
|
||||
hidden_size=d_model,
|
||||
expand=2,
|
||||
num_hidden_layers=n_layers*2,
|
||||
residual_in_fp32=True,
|
||||
))
|
||||
if self.gradient_checkpointing and not self.model.gradient_checkpointing:
|
||||
self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
|
||||
use_reentrant=False
|
||||
))
|
||||
elif self.arch_type in ["mamba"]:
|
||||
self.model = MambaModel(MambaConfig(
|
||||
vocab_size=n_vocab,
|
||||
hidden_size=d_model,
|
||||
expand=2,
|
||||
num_hidden_layers=n_layers*2,
|
||||
residual_in_fp32=True,
|
||||
))
|
||||
if self.gradient_checkpointing and not self.model.gradient_checkpointing:
|
||||
self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
|
||||
use_reentrant=False
|
||||
))
|
||||
else:
|
||||
raise RuntimeError(f'Unknown arch specified: {self.arch_type}')
|
||||
|
||||
if hasattr( self.model, "embeddings" ):
|
||||
del self.model.embeddings
|
||||
self.model = LlamaModel(LlamaConfig(
|
||||
vocab_size=n_vocab,
|
||||
hidden_size=d_model,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
intermediate_size=d_model*d_ffn,
|
||||
num_hidden_layers=n_layers,
|
||||
num_attention_heads=n_heads,
|
||||
attention_dropout=p_dropout if training else 0.0,
|
||||
num_key_value_heads=n_heads,
|
||||
hidden_act="gelu",
|
||||
is_encoder_decoder=False,
|
||||
is_decoder=True,
|
||||
attn_implementation=hf_attention,
|
||||
#gradient_checkpointing=self.gradient_checkpointing,
|
||||
))
|
||||
|
||||
if not split_classifiers:
|
||||
self.classifier = nn.Linear(d_model, n_vocab, bias=classifiers_bias)
|
||||
|
@ -794,77 +501,38 @@ class Base(nn.Module):
|
|||
hidden_states = None
|
||||
|
||||
# HF transformer derived model
|
||||
if self.arch_type in ["llama", "mistral", "mixtral"]:
|
||||
kwargs = dict(
|
||||
inputs_embeds=x,
|
||||
attention_mask=m,
|
||||
past_key_values=state,
|
||||
position_ids=position_ids,
|
||||
use_cache=False, # not self.training,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=True,
|
||||
is_causal=is_causal,
|
||||
)
|
||||
kwargs = dict(
|
||||
inputs_embeds=x,
|
||||
attention_mask=m,
|
||||
past_key_values=state,
|
||||
position_ids=position_ids,
|
||||
use_cache=False, # not self.training,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=True,
|
||||
is_causal=is_causal,
|
||||
)
|
||||
|
||||
if self.n_experts > 1 and self.training:
|
||||
kwargs["output_router_logits"] = True
|
||||
if self.n_experts > 1 and self.training:
|
||||
kwargs["output_router_logits"] = True
|
||||
|
||||
output = self.model(**kwargs)
|
||||
x = output["last_hidden_state"]
|
||||
|
||||
# to-do: figure out why KV caching doesn't work
|
||||
#if not self.training:
|
||||
if state is not None:
|
||||
state = output["past_key_values"]
|
||||
output = self.model(**kwargs)
|
||||
x = output["last_hidden_state"]
|
||||
|
||||
# to-do: figure out why KV caching doesn't work
|
||||
#if not self.training:
|
||||
if state is not None:
|
||||
state = output["past_key_values"]
|
||||
|
||||
if output_attentions:
|
||||
attentions = output["attentions"]
|
||||
|
||||
if output_hidden_states:
|
||||
hidden_states = output["hidden_states"]
|
||||
|
||||
if self.n_experts > 1 and self.training:
|
||||
router_logits = output["router_logits"]
|
||||
aux_loss = self.model.config.router_aux_loss_coef * load_balancing_loss_func( router_logits, self.model.config.num_local_experts, self.model.config.num_experts_per_tok, m )
|
||||
|
||||
elif self.arch_type == "transformer":
|
||||
# ensures we specify a quant_level for the transformer implementation's AdaLN
|
||||
l = torch.zeros((batch_size,), dtype=torch.int32) if quant_levels is None else quant_levels
|
||||
l = l.to(device)
|
||||
# inject position information
|
||||
x = self.sin_emb.add_pe(x)
|
||||
# pass our inputs through the transformer
|
||||
for block in self.blocks:
|
||||
x = block(x, m, l)
|
||||
elif self.arch_type == "retnet":
|
||||
# pass our inputs through the RetNet
|
||||
x, _ = self.model(x, incremental_state=state, token_embeddings=x, features_only=True)
|
||||
if _ is not None and "l_aux" in _ and self.n_experts > 1:
|
||||
aux_loss = torch.sum(torch.stack([ t for t in _["l_aux"] if t is not None])) * 0.001
|
||||
elif self.arch_type in ["mamba","mamba2"]:
|
||||
kwargs = dict(
|
||||
inputs_embeds=x,
|
||||
attention_mask=m,
|
||||
#cache_params=state,
|
||||
use_cache=False, # not self.training,
|
||||
#position_ids=position_ids,
|
||||
#output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
output = self.model(**kwargs)
|
||||
x = output["last_hidden_state"]
|
||||
|
||||
if state is not None:
|
||||
state = output["cache_params"]
|
||||
|
||||
if output_attentions:
|
||||
attentions = output["attentions"]
|
||||
|
||||
if output_hidden_states:
|
||||
hidden_states = output["hidden_states"]
|
||||
if output_attentions:
|
||||
attentions = output["attentions"]
|
||||
|
||||
if output_hidden_states:
|
||||
hidden_states = output["hidden_states"]
|
||||
|
||||
if self.n_experts > 1 and self.training:
|
||||
router_logits = output["router_logits"]
|
||||
aux_loss = self.model.config.router_aux_loss_coef * load_balancing_loss_func( router_logits, self.model.config.num_local_experts, self.model.config.num_experts_per_tok, m )
|
||||
|
||||
# process it into a format that I like
|
||||
if output_hidden_states:
|
||||
|
@ -1210,20 +878,6 @@ class Base(nn.Module):
|
|||
quant_level
|
||||
)
|
||||
else:
|
||||
"""
|
||||
offset = 0
|
||||
if "nar" not in self.capabilities:
|
||||
offset = 0
|
||||
elif quant_level > 0:
|
||||
offset = 1
|
||||
|
||||
embedding = self.resps_emb(
|
||||
input if input.dim() == 1 or quant_level == 0 else input[:, :quant_level],
|
||||
offset = offset,
|
||||
quant_level = 0 if quant_level == 0 else quant_level - 1, # input is one below the target quant level
|
||||
)
|
||||
"""
|
||||
|
||||
embedding = self.resps_emb(
|
||||
input if input.dim() == 1 or quant_level == 0 else input[:, :quant_level],
|
||||
#offset = 0 if classifier_level.startswith("AR:") else 1,
|
||||
|
@ -1477,6 +1131,9 @@ class Base(nn.Module):
|
|||
if name == "resp":
|
||||
name = f'{name}[{quant_level}]'
|
||||
"""
|
||||
|
||||
nll, metrics = _calc_loss( logits[batch_index][start:end], token.long(), causal )
|
||||
|
||||
if nll is not None:
|
||||
if f'{name}.nll' not in loss:
|
||||
loss[f'{name}.nll'] = []
|
||||
|
@ -1690,7 +1347,7 @@ class Base(nn.Module):
|
|||
logits = [ logit[-self.causal_size:] for logit in logits ]
|
||||
|
||||
# (NAR) disable stop token
|
||||
if quant_levels is not None and "ar" in self.capabilities:
|
||||
if quant_levels is not None:
|
||||
logits = [ ban_tokens(logit, tokens=[self.stop_token]) for logit, l in zip( logits, map(len, prev_list) ) ]
|
||||
# (AR-len) disable extraneous tokens
|
||||
"""
|
||||
|
|
Loading…
Reference in New Issue
Block a user