implemented xformers in HF's Llama (because theres no flash attention for Volta cards)
This commit is contained in:
parent
277dcec484
commit
3dca1125f5
|
@ -5,7 +5,7 @@ import traceback
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import re
|
import re
|
||||||
|
|
||||||
from typing import Literal, overload
|
from typing import Literal, overload, Optional, Tuple
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
|
|
||||||
|
@ -132,6 +132,107 @@ try:
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print("Error importing `mixtral` arch:", e)
|
print("Error importing `mixtral` arch:", e)
|
||||||
|
|
||||||
|
|
||||||
|
LLAMA_ATTENTIONS = {}
|
||||||
|
try:
|
||||||
|
from transformers.cache_utils import Cache
|
||||||
|
from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb
|
||||||
|
|
||||||
|
from xformers.ops import LowerTriangularMask
|
||||||
|
from xformers.ops.fmha import memory_efficient_attention
|
||||||
|
|
||||||
|
class LLamaXformersAttention(LlamaAttention):
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_value: Optional[Cache] = None,
|
||||||
|
output_attentions: bool = False,
|
||||||
|
use_cache: bool = False,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
|
bsz, q_len, _ = hidden_states.size()
|
||||||
|
|
||||||
|
query_states = self.q_proj(hidden_states)
|
||||||
|
key_states = self.k_proj(hidden_states)
|
||||||
|
value_states = self.v_proj(hidden_states)
|
||||||
|
|
||||||
|
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||||
|
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||||
|
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||||
|
|
||||||
|
cos, sin = self.rotary_emb(value_states, position_ids)
|
||||||
|
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
||||||
|
|
||||||
|
past_key_value = getattr(self, "past_key_value", past_key_value)
|
||||||
|
|
||||||
|
if past_key_value is not None:
|
||||||
|
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
||||||
|
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
||||||
|
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||||
|
|
||||||
|
query_states = query_states.transpose(1, 2)
|
||||||
|
key_states = key_states.transpose(1, 2)
|
||||||
|
value_states = value_states.transpose(1, 2)
|
||||||
|
|
||||||
|
dropout_rate = self.attention_dropout if self.training else 0.0
|
||||||
|
|
||||||
|
# copied from https://github.com/oobabooga/text-generation-webui/pull/950/files
|
||||||
|
if attention_mask is None or attention_mask[0, 0, 0, 1] == 0:
|
||||||
|
attn_output = memory_efficient_attention(query_states, key_states, value_states, attn_bias=None)
|
||||||
|
else:
|
||||||
|
attn_output = memory_efficient_attention(query_states, key_states, value_states, attn_bias=LowerTriangularMask())
|
||||||
|
|
||||||
|
attn_weights = None
|
||||||
|
|
||||||
|
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||||
|
attn_output = self.o_proj(attn_output)
|
||||||
|
|
||||||
|
return attn_output, attn_weights, past_key_value
|
||||||
|
|
||||||
|
LLAMA_ATTENTIONS["xformers"] = LLamaXformersAttention
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print("Error creating `LLamaXformersAttention`:", e)
|
||||||
|
|
||||||
|
def replace_attention( model, impl, verbose=True ):
|
||||||
|
device = next(model.parameters()).device
|
||||||
|
dtype = next(model.parameters()).dtype
|
||||||
|
attentions = [k.split('.') for k, m in model.named_modules() if isinstance(m, LlamaAttention)]
|
||||||
|
|
||||||
|
if impl not in LLAMA_ATTENTIONS:
|
||||||
|
print(f"Attention '{imp} is not in LLAMA_ATTENTIONS'")
|
||||||
|
return model
|
||||||
|
|
||||||
|
klass = LLAMA_ATTENTIONS[impl]
|
||||||
|
|
||||||
|
for *parent, k in attentions:
|
||||||
|
name = '.'.join(parent)
|
||||||
|
|
||||||
|
# copy parameters
|
||||||
|
m = getattr( model.get_submodule(name), k )
|
||||||
|
|
||||||
|
if isinstance(m, klass):
|
||||||
|
continue
|
||||||
|
|
||||||
|
config = m.config
|
||||||
|
layer_idx = m.layer_idx
|
||||||
|
|
||||||
|
kwargs = dict(config=config, layer_idx=layer_idx)
|
||||||
|
|
||||||
|
# overwrite
|
||||||
|
setattr(
|
||||||
|
model.get_submodule(name), k,
|
||||||
|
klass( **kwargs ).to(device=device, dtype=dtype)
|
||||||
|
)
|
||||||
|
|
||||||
|
if verbose:
|
||||||
|
print(f"Replacing {name}.{k} to", klass)
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
def _create_mask(l, device):
|
def _create_mask(l, device):
|
||||||
"""1 is valid region and 0 is invalid."""
|
"""1 is valid region and 0 is invalid."""
|
||||||
seq = torch.arange(max(l), device=device).unsqueeze(0) # (1 t)
|
seq = torch.arange(max(l), device=device).unsqueeze(0) # (1 t)
|
||||||
|
@ -431,6 +532,14 @@ class Base(nn.Module):
|
||||||
))
|
))
|
||||||
elif self.arch_type == "llama":
|
elif self.arch_type == "llama":
|
||||||
if n_experts <= 1:
|
if n_experts <= 1:
|
||||||
|
# ick, there has to be a better way
|
||||||
|
attention = self.config.attention if self.config is not None else None # "flash_attention_2",
|
||||||
|
use_xformers = False
|
||||||
|
|
||||||
|
if attention == "xformers":
|
||||||
|
use_xformers = True
|
||||||
|
attention = None
|
||||||
|
|
||||||
self.model = LlamaModel(LlamaConfig(
|
self.model = LlamaModel(LlamaConfig(
|
||||||
vocab_size=n_resp_tokens,
|
vocab_size=n_resp_tokens,
|
||||||
hidden_size=d_model,
|
hidden_size=d_model,
|
||||||
|
@ -444,8 +553,11 @@ class Base(nn.Module):
|
||||||
hidden_act="gelu",
|
hidden_act="gelu",
|
||||||
is_encoder_decoder=False,
|
is_encoder_decoder=False,
|
||||||
is_decoder=True,
|
is_decoder=True,
|
||||||
attn_implementation=self.config.attention if self.config is not None else None, # "flash_attention_2",
|
attn_implementation=attention,
|
||||||
))
|
))
|
||||||
|
|
||||||
|
if use_xformers:
|
||||||
|
self.model = replace_attention( self.model, "xformers" if use_xformers else attention )
|
||||||
else:
|
else:
|
||||||
self.model = MixtralModel(MixtralConfig(
|
self.model = MixtralModel(MixtralConfig(
|
||||||
vocab_size =n_resp_tokens,
|
vocab_size =n_resp_tokens,
|
||||||
|
|
Loading…
Reference in New Issue
Block a user