implemented xformers in HF's Llama (because there's no flash attention for Volta cards)

mrq 2024-05-04 13:07:45 -05:00
parent 277dcec484
commit 3dca1125f5


@@ -5,7 +5,7 @@ import traceback
import numpy as np
import re
from typing import Literal, overload, Optional, Tuple
from functools import partial
from einops import rearrange
@@ -132,6 +132,107 @@ try:
except Exception as e:
	print("Error importing `mixtral` arch:", e)
LLAMA_ATTENTIONS = {}
try:
	from transformers.cache_utils import Cache
	from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb
	from xformers.ops import LowerTriangularMask
	from xformers.ops.fmha import memory_efficient_attention

	class LLamaXformersAttention(LlamaAttention):
		def forward(
			self,
			hidden_states: torch.Tensor,
			attention_mask: Optional[torch.Tensor] = None,
			position_ids: Optional[torch.LongTensor] = None,
			past_key_value: Optional[Cache] = None,
			output_attentions: bool = False,
			use_cache: bool = False,
			cache_position: Optional[torch.LongTensor] = None,
			**kwargs,
		) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
			bsz, q_len, _ = hidden_states.size()

			query_states = self.q_proj(hidden_states)
			key_states = self.k_proj(hidden_states)
			value_states = self.v_proj(hidden_states)

			query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
			key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
			value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

			cos, sin = self.rotary_emb(value_states, position_ids)
			query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

			past_key_value = getattr(self, "past_key_value", past_key_value)

			if past_key_value is not None:
				# sin and cos are specific to RoPE models; cache_position needed for the static cache
				cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
				key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

			# xformers wants (bsz, q_len, num_heads, head_dim) rather than (bsz, num_heads, q_len, head_dim)
			query_states = query_states.transpose(1, 2)
			key_states = key_states.transpose(1, 2)
			value_states = value_states.transpose(1, 2)

			dropout_rate = self.attention_dropout if self.training else 0.0

			# copied from https://github.com/oobabooga/text-generation-webui/pull/950/files
			# the additive mask from transformers is assumed to be either all zeros or causal,
			# so checking one element above the diagonal tells the two cases apart
			if attention_mask is None or attention_mask[0, 0, 0, 1] == 0:
				attn_output = memory_efficient_attention(query_states, key_states, value_states, attn_bias=None)
			else:
				attn_output = memory_efficient_attention(query_states, key_states, value_states, attn_bias=LowerTriangularMask())

			attn_weights = None

			attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
			attn_output = self.o_proj(attn_output)

			return attn_output, attn_weights, past_key_value

	LLAMA_ATTENTIONS["xformers"] = LLamaXformersAttention
except Exception as e:
	print("Error creating `LLamaXformersAttention`:", e)
def replace_attention( model, impl, verbose=True ):
	# bail out early if the requested implementation never got registered
	# (for example because the xformers import above failed)
	if impl not in LLAMA_ATTENTIONS:
		print(f"Attention '{impl}' is not in LLAMA_ATTENTIONS")
		return model

	klass = LLAMA_ATTENTIONS[impl]

	device = next(model.parameters()).device
	dtype = next(model.parameters()).dtype
	attentions = [k.split('.') for k, m in model.named_modules() if isinstance(m, LlamaAttention)]

	for *parent, k in attentions:
		name = '.'.join(parent)

		# grab the existing module to reuse its config and layer index
		m = getattr( model.get_submodule(name), k )

		if isinstance(m, klass):
			continue

		config = m.config
		layer_idx = m.layer_idx

		kwargs = dict(config=config, layer_idx=layer_idx)

		# overwrite with a freshly-initialized instance of the requested attention class
		setattr(
			model.get_submodule(name), k,
			klass( **kwargs ).to(device=device, dtype=dtype)
		)

		if verbose:
			print(f"Replacing {name}.{k} with", klass)

	return model
def _create_mask(l, device):
	"""1 is valid region and 0 is invalid."""
	seq = torch.arange(max(l), device=device).unsqueeze(0) # (1 t)
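
As a quick illustration of how the new helper is meant to be used (not part of this commit, and assuming the xformers import above succeeded so that LLAMA_ATTENTIONS["xformers"] is registered), a throwaway Hugging Face LlamaModel with arbitrary small dimensions could be patched like this:

from transformers import LlamaConfig, LlamaModel

model = LlamaModel(LlamaConfig(
	vocab_size=256,
	hidden_size=256,
	intermediate_size=512,
	num_hidden_layers=2,
	num_attention_heads=4,
))
model = replace_attention( model, "xformers" )

# every decoder layer's self_attn should now be LLamaXformersAttention
print( set( type(layer.self_attn).__name__ for layer in model.layers ) )

Since memory_efficient_attention expects query, key, and value to share the same head count, this sketch leaves num_key_value_heads at its default (equal to num_attention_heads).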
@@ -431,6 +532,14 @@ class Base(nn.Module):
			))
		elif self.arch_type == "llama":
			if n_experts <= 1:
				# ick, there has to be a better way
				attention = self.config.attention if self.config is not None else None # "flash_attention_2"
				use_xformers = False

				if attention == "xformers":
					use_xformers = True
					attention = None
				self.model = LlamaModel(LlamaConfig(
					vocab_size=n_resp_tokens,
					hidden_size=d_model,
@@ -444,8 +553,11 @@
					hidden_act="gelu",
					is_encoder_decoder=False,
					is_decoder=True,
					attn_implementation=attention,
				))

				if use_xformers:
					self.model = replace_attention( self.model, "xformers" if use_xformers else attention )
			else:
				self.model = MixtralModel(MixtralConfig(
					vocab_size =n_resp_tokens,