2024-06-06 01:30:43 +00:00
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
2024-08-19 06:03:35 +00:00
import math
2024-06-06 01:30:43 +00:00
import torch
2024-08-29 18:27:16 +00:00
import logging
2024-06-06 01:30:43 +00:00
from typing import Literal , overload , Optional , Tuple
from torch import Tensor , nn
from transformers . cache_utils import Cache
from transformers import LlamaModel , LlamaConfig , LlamaForCausalLM
2024-08-04 00:51:00 +00:00
from transformers . models . llama . modeling_llama import LlamaAttention , apply_rotary_pos_emb , repeat_kv
2024-06-06 01:30:43 +00:00
2024-08-29 18:27:16 +00:00
_logger = logging . getLogger ( __name__ )
2024-08-07 01:42:39 +00:00
AVAILABLE_ATTENTIONS = [ ]
2024-08-04 00:51:00 +00:00
2024-08-09 00:38:55 +00:00
try :
from transformers . utils import is_flash_attn_2_available
if is_flash_attn_2_available ( ) :
AVAILABLE_ATTENTIONS . append ( " flash_attention_2 " )
except Exception as e :
2024-08-29 18:27:16 +00:00
_logger . warning ( f " Error while querying for `flash_attention_2` support: { str ( e ) } " )
2024-08-19 01:51:14 +00:00
try :
2024-08-27 00:13:34 +00:00
from . attention . fused import attention as _fused_attention
def fused_attn_func ( q , k , v , softmax_scale = None , causal = False , * args , * * kwargs ) :
return _fused_attention ( q , k , v , causal , softmax_scale )
AVAILABLE_ATTENTIONS . append ( " fused_attn " )
except Exception as e :
2024-08-29 18:27:16 +00:00
_logger . warning ( f " Error while querying for `fused_attn` support: { str ( e ) } " )
2024-08-26 01:07:27 +00:00
2024-08-27 00:13:34 +00:00
is_rocm = any ( " AMD " in torch . cuda . get_device_properties ( i ) . name for i in range ( torch . cuda . device_count ( ) ) )
is_ampere_or_newer_gpu = any ( torch . cuda . get_device_properties ( i ) . major > = 8 for i in range ( torch . cuda . device_count ( ) ) )
2024-08-26 01:07:27 +00:00
2024-08-27 00:13:34 +00:00
try :
if is_rocm :
# requires pain to set up on Navi3, and for no backwards (training) support
from flash_attn import flash_attn_func
2024-08-26 01:07:27 +00:00
AVAILABLE_ATTENTIONS . append ( " flash_attn " )
2024-08-27 00:13:34 +00:00
2024-08-26 01:07:27 +00:00
elif not is_ampere_or_newer_gpu :
2024-08-19 04:54:33 +00:00
# Uses https://github.com/ZRayZzz/flash-attention-v100/
# Currently doesn't work because it's hard-coded to use a head dim of 128, will throw NaNs otherwise...
from flash_attn_v100 import flash_attn_func as flash_attn_v100_func
2024-08-19 01:51:14 +00:00
AVAILABLE_ATTENTIONS . append ( " flash_attn " )
2024-08-19 04:54:33 +00:00
AVAILABLE_ATTENTIONS . append ( " flash_attn_v100 " ) # needed to signal to use padding
def flash_attn_func ( q , k , v , softmax_scale = None , causal = False , * args , * * kwargs ) :
return flash_attn_v100_func (
q ,
k ,
v ,
softmax_scale ,
causal
2024-08-19 01:51:14 +00:00
)
2024-08-19 04:54:33 +00:00
else :
# Borrowed from https://github.com/turboderp/exllamav2/blob/master/exllamav2/attn.py#L32
# Adapted to provide flash_attn_v1 support
import flash_attn
flash_attn_ver = [ int ( t ) for t in flash_attn . __version__ . split ( " . " ) if t . isdigit ( ) ]
if flash_attn_ver < = [ 1 , 0 , 9 ] :
AVAILABLE_ATTENTIONS . append ( " flash_attn " )
from flash_attn . flash_attn_interface import flash_attn_unpadded_func
from einops import rearrange
# converts the flash_attn_2 calling convention to flash_attn_1's
def flash_attn_func ( q , k , v , dropout_p = 0.0 , softmax_scale = None , causal = False , return_attn_probs = False , deterministic = False , * args , * * kwargs ) :
batch_size , seqlen_q = q . shape [ 0 ] , q . shape [ 1 ]
seqlen_k = k . shape [ 1 ]
q , k , v = [ rearrange ( x , ' b s ... -> (b s) ... ' ) . contiguous ( ) for x in [ q , k , v ] ]
cu_seqlens_q = torch . arange ( 0 , ( batch_size + 1 ) * seqlen_q , step = seqlen_q , dtype = torch . int32 , device = q . device )
cu_seqlens_k = cu_seqlens_q
return flash_attn_unpadded_func (
q , k , v ,
cu_seqlens_q , cu_seqlens_k , seqlen_q , seqlen_k ,
dropout_p , softmax_scale , causal , return_attn_probs , deterministic
)
has_flash_attn = True
elif [ 2 , 2 , 1 ] < = flash_attn_ver < [ 2 , 5 , 7 ] :
AVAILABLE_ATTENTIONS . append ( " flash_attn " )
from flash_attn import flash_attn_func
has_flash_attn = True
elif [ 2 , 5 , 7 ] < = flash_attn_ver :
AVAILABLE_ATTENTIONS . append ( " flash_attn " )
from flash_attn import flash_attn_func , flash_attn_with_kvcache
signature = list ( inspect . signature ( flash_attn_func ) . parameters )
has_flash_attn_with_window = " window_size " in signature
has_flash_attn_with_softcap = " softcap " in signature
import flash_attn_2_cuda as flash_attn_cuda
has_flash_attn = True
has_flash_attn_with_paged = True
2024-08-19 01:51:14 +00:00
except Exception as e :
2024-08-29 18:27:16 +00:00
_logger . warning ( f " Error while querying for `flash_attn` support: { str ( e ) } " )
2024-08-09 00:38:55 +00:00
try :
from xformers . ops . fmha import memory_efficient_attention
2024-08-19 06:03:35 +00:00
from xformers . ops . fmha . attn_bias import LowerTriangularFromBottomRightMask , LowerTriangularMask
2024-08-09 00:38:55 +00:00
AVAILABLE_ATTENTIONS . append ( " xformers " )
except Exception as e :
2024-08-29 18:27:16 +00:00
_logger . warning ( f " Error while importing `xformers`: { str ( e ) } " )
2024-08-09 00:38:55 +00:00
2024-08-26 01:07:27 +00:00
# to-do: find a better way to query for if there's available kernels since these return true regardless
2024-08-04 00:51:00 +00:00
if torch . backends . cuda . flash_sdp_enabled ( ) :
2024-10-18 21:55:00 +00:00
AVAILABLE_ATTENTIONS . append ( " flash_(sdpa) " )
2024-08-04 00:51:00 +00:00
if torch . backends . cuda . mem_efficient_sdp_enabled ( ) :
AVAILABLE_ATTENTIONS . append ( " mem_efficient " )
if torch . backends . cuda . math_sdp_enabled ( ) :
AVAILABLE_ATTENTIONS . append ( " math " )
2024-06-06 01:30:43 +00:00
2024-08-07 01:42:39 +00:00
if torch . backends . cuda . cudnn_sdp_enabled ( ) :
AVAILABLE_ATTENTIONS . append ( " cudnn " )
if AVAILABLE_ATTENTIONS :
2024-10-12 16:27:55 +00:00
AVAILABLE_ATTENTIONS . append ( " sdpa " )
AVAILABLE_ATTENTIONS . append ( " default " )
2024-08-07 01:42:39 +00:00
2024-08-04 00:51:00 +00:00
class LlamaAttention_Adapted ( LlamaAttention ) :
2024-06-06 01:30:43 +00:00
def __init__ ( self , * args , * * kwargs ) :
if ' mode ' in kwargs :
self . mode = kwargs [ ' mode ' ]
kwargs . pop ( " mode " )
else :
2024-10-13 04:53:13 +00:00
self . mode = " sdpa "
2024-06-06 01:30:43 +00:00
2024-08-04 00:51:00 +00:00
if self . mode == " math " :
self . mode = torch . nn . attention . SDPBackend . MATH
elif self . mode == " mem_efficient " :
self . mode = torch . nn . attention . SDPBackend . EFFICIENT_ATTENTION
2024-10-18 21:55:00 +00:00
elif self . mode == " flash_(sdpa) " :
2024-08-04 00:51:00 +00:00
self . mode = torch . nn . attention . SDPBackend . FLASH_ATTENTION
elif self . mode == " cudnn " :
self . mode = torch . nn . attention . SDPBackend . CUDNN_ATTENTION
2024-06-06 01:30:43 +00:00
super ( ) . __init__ ( * args , * * kwargs )
2024-08-19 06:03:35 +00:00
# Adapted from LlamaAttention.forward
2024-08-04 00:51:00 +00:00
def forward (
self ,
hidden_states : torch . Tensor ,
attention_mask : Optional [ torch . Tensor ] = None ,
position_ids : Optional [ torch . LongTensor ] = None ,
past_key_value : Optional [ Cache ] = None ,
output_attentions : bool = False ,
use_cache : bool = False ,
cache_position : Optional [ torch . LongTensor ] = None ,
position_embeddings : Optional [ Tuple [ torch . Tensor , torch . Tensor ] ] = None , # will become mandatory in v4.45
* * kwargs ,
) - > Tuple [ torch . Tensor , Optional [ torch . Tensor ] , Optional [ Tuple [ torch . Tensor ] ] ] :
2024-10-13 05:26:46 +00:00
mode = " default " if output_attentions else self . mode
2024-08-19 06:03:35 +00:00
dropout_rate = self . attention_dropout if self . training else 0.0
2024-08-04 00:51:00 +00:00
bsz , q_len , _ = hidden_states . size ( )
query_states = self . q_proj ( hidden_states )
key_states = self . k_proj ( hidden_states )
value_states = self . v_proj ( hidden_states )
query_states = query_states . view ( bsz , q_len , self . num_heads , self . head_dim ) . transpose ( 1 , 2 )
key_states = key_states . view ( bsz , q_len , self . num_key_value_heads , self . head_dim ) . transpose ( 1 , 2 )
value_states = value_states . view ( bsz , q_len , self . num_key_value_heads , self . head_dim ) . transpose ( 1 , 2 )
if position_embeddings is None :
cos , sin = self . rotary_emb ( value_states , position_ids )
else :
cos , sin = position_embeddings
query_states , key_states = apply_rotary_pos_emb ( query_states , key_states , cos , sin )
if past_key_value is not None :
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = { " sin " : sin , " cos " : cos , " cache_position " : cache_position }
key_states , value_states = past_key_value . update ( key_states , value_states , self . layer_idx , cache_kwargs )
2024-10-13 16:02:24 +00:00
attn_scores = None
2024-10-13 05:26:46 +00:00
if mode in [ " xformers " , " flash_attn " ] :
2024-08-19 06:03:35 +00:00
# TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
# to be able to avoid many of these transpose/reshape/view.
query_states = query_states . transpose ( 1 , 2 )
key_states = key_states . transpose ( 1 , 2 )
value_states = value_states . transpose ( 1 , 2 )
"""
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
# therefore the input hidden states gets silently casted in float32. Hence, we need
# cast them back in the correct dtype just to be sure everything works as expected.
# This might slowdown training & inference so it is recommended to not cast the LayerNorms
# in fp32. (LlamaRMSNorm handles it correctly)
input_dtype = query_states . dtype
if input_dtype == torch . float32 :
if torch . is_autocast_enabled ( ) :
target_dtype = torch . get_autocast_gpu_dtype ( )
# Handle the case where the model is quantized
elif hasattr ( self . config , " _pre_quantization_dtype " ) :
target_dtype = self . config . _pre_quantization_dtype
else :
target_dtype = self . q_proj . weight . dtype
query_states = query_states . to ( target_dtype )
key_states = key_states . to ( target_dtype )
value_states = value_states . to ( target_dtype )
"""
2024-10-13 05:26:46 +00:00
if mode == " flash_attn " :
2024-08-19 06:03:35 +00:00
attn_output = flash_attn_func (
query_states ,
key_states ,
value_states ,
causal = True ,
softmax_scale = 1.0 / math . sqrt ( self . head_dim ) ,
dropout_p = dropout_rate ,
)
attn_output = attn_output . reshape ( bsz , q_len , - 1 ) . contiguous ( )
2024-10-13 05:26:46 +00:00
elif mode == " xformers " :
2024-08-19 06:03:35 +00:00
attn_output = memory_efficient_attention (
query_states ,
key_states ,
value_states ,
attn_bias = LowerTriangularMask ( ) if attention_mask is None or attention_mask [ 0 , 0 , 0 , 1 ] == 0 else None ,
scale = 1.0 / math . sqrt ( self . head_dim ) ,
p = dropout_rate
)
attn_output = attn_output . reshape ( bsz , q_len , self . hidden_size )
attn_output = self . o_proj ( attn_output )
2024-10-12 17:09:17 +00:00
return attn_output , attn_scores , past_key_value
2024-08-19 06:03:35 +00:00
2024-08-04 00:51:00 +00:00
key_states = repeat_kv ( key_states , self . num_key_value_groups )
value_states = repeat_kv ( value_states , self . num_key_value_groups )
2024-10-13 04:53:13 +00:00
2024-08-04 00:51:00 +00:00
causal_mask = attention_mask
if attention_mask is not None :
causal_mask = causal_mask [ : , : , : , : key_states . shape [ - 2 ] ]
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
if query_states . device . type == " cuda " and causal_mask is not None :
query_states = query_states . contiguous ( )
key_states = key_states . contiguous ( )
value_states = value_states . contiguous ( )
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
is_causal = True if causal_mask is None and q_len > 1 else False
2024-08-19 06:03:35 +00:00
2024-10-13 05:26:46 +00:00
if mode in [ " fused_attn " ] :
2024-08-27 00:13:34 +00:00
attn_output = fused_attn_func (
2024-08-04 00:51:00 +00:00
query_states ,
key_states ,
value_states ,
2024-08-27 00:13:34 +00:00
causal = True ,
softmax_scale = 1.0 / math . sqrt ( self . head_dim ) ,
2024-08-19 06:03:35 +00:00
dropout_p = dropout_rate ,
2024-08-04 00:51:00 +00:00
)
2024-10-13 05:26:46 +00:00
elif mode in [ " default " ] :
attn_scores = torch . matmul ( query_states , key_states . transpose ( 2 , 3 ) ) / math . sqrt ( self . head_dim )
# cringe logic
attn_weights = ( attn_scores + causal_mask ) if attention_mask is not None else ( attn_scores )
# upcast attention to fp32
attn_weights = nn . functional . softmax ( attn_weights , dim = - 1 , dtype = torch . float32 ) . to ( query_states . dtype )
attn_weights = nn . functional . dropout ( attn_weights , p = self . attention_dropout , training = self . training )
attn_output = torch . matmul ( attn_weights , value_states )
if attn_output . size ( ) != ( bsz , self . num_heads , q_len , self . head_dim ) :
raise ValueError (
f " `attn_output` should be of size { ( bsz , self . num_heads , q_len , self . head_dim ) } , but is "
f " { attn_output . size ( ) } "
)
2024-08-27 00:13:34 +00:00
else :
with torch . nn . attention . sdpa_kernel ( self . mode ) :
attn_output = torch . nn . functional . scaled_dot_product_attention (
query_states ,
key_states ,
value_states ,
attn_mask = causal_mask ,
dropout_p = dropout_rate ,
is_causal = is_causal ,
)
2024-08-04 00:51:00 +00:00
2024-10-13 17:01:12 +00:00
# cringe
if attn_scores is None and output_attentions :
attn_scores = attn_output
2024-08-04 00:51:00 +00:00
attn_output = attn_output . transpose ( 1 , 2 ) . contiguous ( )
attn_output = attn_output . view ( bsz , q_len , - 1 )
attn_output = self . o_proj ( attn_output )
2024-10-12 17:09:17 +00:00
return attn_output , attn_scores , past_key_value