implemented xformers in HF's Llama (because theres no flash attention for Volta cards)

This commit is contained in:
mrq 2024-05-04 13:07:45 -05:00
parent 277dcec484
commit 3dca1125f5

View File

@ -5,7 +5,7 @@ import traceback
import numpy as np
import re
from typing import Literal, overload
from typing import Literal, overload, Optional, Tuple
from functools import partial
from einops import rearrange
@ -132,6 +132,107 @@ try:
except Exception as e:
print("Error importing `mixtral` arch:", e)
LLAMA_ATTENTIONS = {}
try:
from transformers.cache_utils import Cache
from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb
from xformers.ops import LowerTriangularMask
from xformers.ops.fmha import memory_efficient_attention
class LLamaXformersAttention(LlamaAttention):
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
past_key_value = getattr(self, "past_key_value", past_key_value)
if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
dropout_rate = self.attention_dropout if self.training else 0.0
# copied from https://github.com/oobabooga/text-generation-webui/pull/950/files
if attention_mask is None or attention_mask[0, 0, 0, 1] == 0:
attn_output = memory_efficient_attention(query_states, key_states, value_states, attn_bias=None)
else:
attn_output = memory_efficient_attention(query_states, key_states, value_states, attn_bias=LowerTriangularMask())
attn_weights = None
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights, past_key_value
LLAMA_ATTENTIONS["xformers"] = LLamaXformersAttention
except Exception as e:
print("Error creating `LLamaXformersAttention`:", e)
def replace_attention( model, impl, verbose=True ):
device = next(model.parameters()).device
dtype = next(model.parameters()).dtype
attentions = [k.split('.') for k, m in model.named_modules() if isinstance(m, LlamaAttention)]
if impl not in LLAMA_ATTENTIONS:
print(f"Attention '{imp} is not in LLAMA_ATTENTIONS'")
return model
klass = LLAMA_ATTENTIONS[impl]
for *parent, k in attentions:
name = '.'.join(parent)
# copy parameters
m = getattr( model.get_submodule(name), k )
if isinstance(m, klass):
continue
config = m.config
layer_idx = m.layer_idx
kwargs = dict(config=config, layer_idx=layer_idx)
# overwrite
setattr(
model.get_submodule(name), k,
klass( **kwargs ).to(device=device, dtype=dtype)
)
if verbose:
print(f"Replacing {name}.{k} to", klass)
return model
def _create_mask(l, device):
"""1 is valid region and 0 is invalid."""
seq = torch.arange(max(l), device=device).unsqueeze(0) # (1 t)
@ -431,6 +532,14 @@ class Base(nn.Module):
))
elif self.arch_type == "llama":
if n_experts <= 1:
# ick, there has to be a better way
attention = self.config.attention if self.config is not None else None # "flash_attention_2",
use_xformers = False
if attention == "xformers":
use_xformers = True
attention = None
self.model = LlamaModel(LlamaConfig(
vocab_size=n_resp_tokens,
hidden_size=d_model,
@ -444,8 +553,11 @@ class Base(nn.Module):
hidden_act="gelu",
is_encoder_decoder=False,
is_decoder=True,
attn_implementation=self.config.attention if self.config is not None else None, # "flash_attention_2",
attn_implementation=attention,
))
if use_xformers:
self.model = replace_attention( self.model, "xformers" if use_xformers else attention )
else:
self.model = MixtralModel(MixtralConfig(
vocab_size =n_resp_tokens,