2024-06-06 01:30:43 +00:00
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/mixtral/modeling_mixtral.py
2024-08-27 22:02:42 +00:00
import math
2024-06-06 01:30:43 +00:00
import torch
2024-06-06 14:48:43 +00:00
import torch . nn . functional as F
2024-08-05 03:03:22 +00:00
from typing import Literal , overload , Optional , Tuple
from transformers . cache_utils import Cache
2024-06-06 01:30:43 +00:00
from transformers import MixtralModel , MixtralConfig
2024-08-05 03:03:22 +00:00
from transformers . models . mixtral . modeling_mixtral import load_balancing_loss_func , MixtralSparseMoeBlock , MixtralAttention , apply_rotary_pos_emb , repeat_kv
2024-06-06 01:30:43 +00:00
2024-08-19 01:51:14 +00:00
try :
from . llama import flash_attn_func
except Exception as e :
pass
2024-08-27 22:02:42 +00:00
try :
from . llama import fused_attn_func
except Exception as e :
pass
try :
from . llama import memory_efficient_attention
except Exception as e :
pass
2024-06-06 01:30:43 +00:00
# This is required because batch sizes > 1 throw errors
2024-06-06 14:48:43 +00:00
def MixtralSparseMoeBlock_forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Route every token to its top-k experts and sum the weighted expert outputs.

    Reads ``self.gate``, ``self.top_k``, ``self.num_experts`` and
    ``self.experts`` from the module it is bound to.

    Args:
        hidden_states: Tensor whose three dimensions are unpacked as
            (batch_size, sequence_length, hidden_dim).

    Returns:
        A tuple ``(final_hidden_states, router_logits)``:
        ``final_hidden_states`` has the same three-dimensional shape as the
        input; ``router_logits`` is the raw output of ``self.gate`` applied to
        the flattened tokens (one row per token).
    """
    batch_size, sequence_length, hidden_dim = hidden_states.shape
    # reshape() rather than view(): view() rejects inputs that are not
    # contiguous, reshape() copies when it has to.
    tokens = hidden_states.reshape(-1, hidden_dim)

    # router_logits: (batch * sequence_length, n_experts)
    router_logits = self.gate(tokens)
    routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
    routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
    # Renormalise so the kept top-k weights of each token sum to 1.
    routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
    # The softmax ran in float32; cast back to the input dtype.
    routing_weights = routing_weights.to(tokens.dtype)

    final_hidden_states = torch.zeros(
        (batch_size * sequence_length, hidden_dim), dtype=tokens.dtype, device=tokens.device
    )

    # expert_mask[e, k, t] is 1 when expert e is the k-th choice of token t.
    expert_mask = F.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)

    for expert_idx in range(self.num_experts):
        # idx: which top-k slot selected this expert; top_x: which token.
        idx, top_x = torch.where(expert_mask[expert_idx])
        if top_x.shape[0] == 0:
            # No token was routed to this expert.
            continue

        # Index with the tensors directly instead of converting them to
        # Python lists first; the selected rows are the same.
        current_state = tokens[top_x]
        current_hidden_states = self.experts[expert_idx](current_state) * routing_weights[top_x, idx, None]

        # A token can be selected by several experts, so accumulate.
        final_hidden_states.index_add_(0, top_x, current_hidden_states.to(tokens.dtype))

    final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
    return final_hidden_states, router_logits
2024-08-05 03:03:22 +00:00
# Monkey-patch: rebind `forward` on the class itself, so every
# MixtralSparseMoeBlock instance uses MixtralSparseMoeBlock_forward.
MixtralSparseMoeBlock . forward = MixtralSparseMoeBlock_forward
class MixtralAttention_Adapted(MixtralAttention):
    """Mixtral attention layer with a selectable attention kernel.

    The extra keyword argument ``mode`` (default ``"math"``) chooses the kernel
    used by :meth:`forward`:

    * ``"math"``, ``"mem_efficient"``, ``"flash"``, ``"cudnn"`` are translated
      to the matching ``torch.nn.attention.SDPBackend`` member and run through
      ``torch.nn.functional.scaled_dot_product_attention``.
    * ``"flash_attn"``, ``"xformers"``, ``"fused_attn"`` call
      ``flash_attn_func``, ``memory_efficient_attention`` and
      ``fused_attn_func`` respectively. Those names are imported optionally at
      the top of this file; if the import failed, the name is undefined when
      the branch runs.

    A non-string ``mode`` (for example an ``SDPBackend`` member) is stored
    unchanged and handed to ``torch.nn.attention.sdpa_kernel``.
    """

    # Mode string -> name of the SDPBackend member. Resolved lazily with
    # getattr so a member missing from the installed torch is only touched
    # when that mode is actually requested.
    _SDPA_BACKEND_NAMES = {
        "math": "MATH",
        "mem_efficient": "EFFICIENT_ATTENTION",
        "flash": "FLASH_ATTENTION",
        "cudnn": "CUDNN_ATTENTION",
    }
    # Mode strings that forward() dispatches to non-SDPA kernels.
    _EXTERNAL_KERNEL_MODES = ("xformers", "flash_attn", "fused_attn")

    def __init__(self, *args, **kwargs):
        """Build the layer; ``mode`` is consumed here, the rest goes to the parent.

        Raises:
            ValueError: if ``mode`` is a string that is not one of the
                supported mode names.
        """
        mode = kwargs.pop("mode", "math")
        if isinstance(mode, str):
            backend_name = self._SDPA_BACKEND_NAMES.get(mode)
            if backend_name is not None:
                mode = getattr(torch.nn.attention.SDPBackend, backend_name)
            elif mode not in self._EXTERNAL_KERNEL_MODES:
                # Fail here instead of deep inside sdpa_kernel() during forward.
                supported = sorted(self._SDPA_BACKEND_NAMES) + list(self._EXTERNAL_KERNEL_MODES)
                raise ValueError(f"Unsupported attention mode {mode!r}; expected one of {supported}")
        super().__init__(*args, **kwargs)
        self.mode = mode

    # Adapted from MixtralAttention.forward
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.45
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Compute self-attention with the kernel selected by ``self.mode``.

        Args:
            hidden_states: Input whose three dimensions are unpacked as
                (batch, query_length, hidden).
            attention_mask: Optional mask. The SDPA path slices it on its last
                dimension to the key length and passes it as ``attn_mask``.
                The xformers path only inspects it to decide whether to apply
                a lower-triangular bias. The ``flash_attn`` and ``fused_attn``
                paths ignore it and always pass ``causal=True``.
            position_ids: Passed to ``apply_rotary_pos_emb``.
            past_key_value: Optional cache, updated with the new keys/values.
            output_attentions: When true, the call is delegated to the parent
                implementation.
            use_cache: Only forwarded to the parent implementation.
            cache_position: Stored in the kwargs given to the cache update.
            position_embeddings: Optional precomputed ``(cos, sin)``; when
                omitted they are computed with ``self.rotary_emb``.

        Returns:
            ``(attn_output, None, past_key_value)`` on every path implemented
            here; the parent's return value when ``output_attentions`` is true.
        """
        if output_attentions:
            # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
            # The fused kernels cannot return attention weights, so fall back
            # to the parent implementation. (The upstream `logger.warning_once`
            # call is disabled in this adaptation.)
            return super().forward(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )

        # Dropout is only active while training.
        dropout_rate = self.attention_dropout if self.training else 0.0

        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # (batch, seq, heads * head_dim) -> (batch, heads, seq, head_dim)
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)

        if position_embeddings is None:
            cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        else:
            cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # Expand the key/value heads to match the number of query heads.
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        softmax_scale = 1.0 / math.sqrt(self.head_dim)

        if self.mode in ("xformers", "flash_attn"):
            # TODO: These transpose are quite inefficient but Flash Attention requires the layout
            # [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
            # to be able to avoid many of these transpose/reshape/view.
            query_states = query_states.transpose(1, 2)
            key_states = key_states.transpose(1, 2)
            value_states = value_states.transpose(1, 2)

            # NOTE: the upstream cast of float32 inputs back to the autocast /
            # pre-quantization / weight dtype is disabled in this adaptation.

            if self.mode == "flash_attn":
                attn_output = flash_attn_func(
                    query_states,
                    key_states,
                    value_states,
                    causal=True,
                    softmax_scale=softmax_scale,
                    dropout_p=dropout_rate,
                )
                attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
            elif self.mode == "xformers":
                # NOTE(review): `LowerTriangularMask` is not imported in the visible part of this
                # file - TODO confirm it is in scope, otherwise this branch raises NameError.
                attn_output = memory_efficient_attention(
                    query_states,
                    key_states,
                    value_states,
                    attn_bias=LowerTriangularMask() if attention_mask is None or attention_mask[0, 0, 0, 1] == 0 else None,
                    scale=softmax_scale,
                    p=dropout_rate,
                )
                attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

            attn_output = self.o_proj(attn_output)
            return attn_output, None, past_key_value

        causal_mask = attention_mask
        if attention_mask is not None:
            causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]

        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs
        # with custom attn_mask. Reference: https://github.com/pytorch/pytorch/issues/112577.
        if query_states.device.type == "cuda" and causal_mask is not None:
            query_states = query_states.contiguous()
            key_states = key_states.contiguous()
            value_states = value_states.contiguous()

        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` variable instead of
        # an inline conditional in the SDPA call, to support both torch.compile's dynamic shapes and full
        # graph options. An inline conditional prevents dynamic shapes from compiling.
        is_causal = causal_mask is None and q_len > 1

        if self.mode == "fused_attn":
            attn_output = fused_attn_func(
                query_states,
                key_states,
                value_states,
                causal=True,
                softmax_scale=softmax_scale,
                dropout_p=dropout_rate,
            )
        else:
            # Restrict SDPA to the backend chosen at construction time.
            with torch.nn.attention.sdpa_kernel(self.mode):
                attn_output = torch.nn.functional.scaled_dot_product_attention(
                    query_states,
                    key_states,
                    value_states,
                    attn_mask=causal_mask,
                    dropout_p=dropout_rate,
                    is_causal=is_causal,
                )

        # (batch, heads, seq, head_dim) -> (batch, seq, heads * head_dim)
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, -1)

        attn_output = self.o_proj(attn_output)

        return attn_output, None, past_key_value