implemented xformers in HF's Llama (because theres no flash attention for Volta cards)
This commit is contained in:
parent
277dcec484
commit
3dca1125f5
|
@ -5,7 +5,7 @@ import traceback
|
|||
import numpy as np
|
||||
import re
|
||||
|
||||
from typing import Literal, overload
|
||||
from typing import Literal, overload, Optional, Tuple
|
||||
from functools import partial
|
||||
from einops import rearrange
|
||||
|
||||
|
@ -132,6 +132,107 @@ try:
|
|||
except Exception as e:
|
||||
print("Error importing `mixtral` arch:", e)
|
||||
|
||||
|
||||
LLAMA_ATTENTIONS = {}
|
||||
try:
|
||||
from transformers.cache_utils import Cache
|
||||
from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb
|
||||
|
||||
from xformers.ops import LowerTriangularMask
|
||||
from xformers.ops.fmha import memory_efficient_attention
|
||||
|
||||
class LLamaXformersAttention(LlamaAttention):
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
cos, sin = self.rotary_emb(value_states, position_ids)
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
||||
|
||||
past_key_value = getattr(self, "past_key_value", past_key_value)
|
||||
|
||||
if past_key_value is not None:
|
||||
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||
|
||||
query_states = query_states.transpose(1, 2)
|
||||
key_states = key_states.transpose(1, 2)
|
||||
value_states = value_states.transpose(1, 2)
|
||||
|
||||
dropout_rate = self.attention_dropout if self.training else 0.0
|
||||
|
||||
# copied from https://github.com/oobabooga/text-generation-webui/pull/950/files
|
||||
if attention_mask is None or attention_mask[0, 0, 0, 1] == 0:
|
||||
attn_output = memory_efficient_attention(query_states, key_states, value_states, attn_bias=None)
|
||||
else:
|
||||
attn_output = memory_efficient_attention(query_states, key_states, value_states, attn_bias=LowerTriangularMask())
|
||||
|
||||
attn_weights = None
|
||||
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
return attn_output, attn_weights, past_key_value
|
||||
|
||||
LLAMA_ATTENTIONS["xformers"] = LLamaXformersAttention
|
||||
|
||||
except Exception as e:
|
||||
print("Error creating `LLamaXformersAttention`:", e)
|
||||
|
||||
def replace_attention( model, impl, verbose=True ):
|
||||
device = next(model.parameters()).device
|
||||
dtype = next(model.parameters()).dtype
|
||||
attentions = [k.split('.') for k, m in model.named_modules() if isinstance(m, LlamaAttention)]
|
||||
|
||||
if impl not in LLAMA_ATTENTIONS:
|
||||
print(f"Attention '{imp} is not in LLAMA_ATTENTIONS'")
|
||||
return model
|
||||
|
||||
klass = LLAMA_ATTENTIONS[impl]
|
||||
|
||||
for *parent, k in attentions:
|
||||
name = '.'.join(parent)
|
||||
|
||||
# copy parameters
|
||||
m = getattr( model.get_submodule(name), k )
|
||||
|
||||
if isinstance(m, klass):
|
||||
continue
|
||||
|
||||
config = m.config
|
||||
layer_idx = m.layer_idx
|
||||
|
||||
kwargs = dict(config=config, layer_idx=layer_idx)
|
||||
|
||||
# overwrite
|
||||
setattr(
|
||||
model.get_submodule(name), k,
|
||||
klass( **kwargs ).to(device=device, dtype=dtype)
|
||||
)
|
||||
|
||||
if verbose:
|
||||
print(f"Replacing {name}.{k} to", klass)
|
||||
|
||||
return model
|
||||
|
||||
def _create_mask(l, device):
|
||||
"""1 is valid region and 0 is invalid."""
|
||||
seq = torch.arange(max(l), device=device).unsqueeze(0) # (1 t)
|
||||
|
@ -431,6 +532,14 @@ class Base(nn.Module):
|
|||
))
|
||||
elif self.arch_type == "llama":
|
||||
if n_experts <= 1:
|
||||
# ick, there has to be a better way
|
||||
attention = self.config.attention if self.config is not None else None # "flash_attention_2",
|
||||
use_xformers = False
|
||||
|
||||
if attention == "xformers":
|
||||
use_xformers = True
|
||||
attention = None
|
||||
|
||||
self.model = LlamaModel(LlamaConfig(
|
||||
vocab_size=n_resp_tokens,
|
||||
hidden_size=d_model,
|
||||
|
@ -444,8 +553,11 @@ class Base(nn.Module):
|
|||
hidden_act="gelu",
|
||||
is_encoder_decoder=False,
|
||||
is_decoder=True,
|
||||
attn_implementation=self.config.attention if self.config is not None else None, # "flash_attention_2",
|
||||
attn_implementation=attention,
|
||||
))
|
||||
|
||||
if use_xformers:
|
||||
self.model = replace_attention( self.model, "xformers" if use_xformers else attention )
|
||||
else:
|
||||
self.model = MixtralModel(MixtralConfig(
|
||||
vocab_size =n_resp_tokens,
|
||||
|
|
Loading…
Reference in New Issue
Block a user