From 3dca1125f5c84b16d14c3607e699694675a1b9e1 Mon Sep 17 00:00:00 2001
From: mrq
Date: Sat, 4 May 2024 13:07:45 -0500
Subject: [PATCH] implemented xformers in HF's Llama (because there's no flash attention for Volta cards)

---
 vall_e/models/base.py | 116 +++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 114 insertions(+), 2 deletions(-)

diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index 214d41e..23d2a71 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -5,7 +5,7 @@ import traceback
 import numpy as np
 import re
 
-from typing import Literal, overload
+from typing import Literal, overload, Optional, Tuple
 from functools import partial
 from einops import rearrange
 
@@ -132,6 +132,107 @@ try:
 except Exception as e:
 	print("Error importing `mixtral` arch:", e)
 
+
+LLAMA_ATTENTIONS = {}
+try:
+	from transformers.cache_utils import Cache
+	from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb
+
+	from xformers.ops import LowerTriangularMask
+	from xformers.ops.fmha import memory_efficient_attention
+
+	class LLamaXformersAttention(LlamaAttention):
+		def forward(
+			self,
+			hidden_states: torch.Tensor,
+			attention_mask: Optional[torch.Tensor] = None,
+			position_ids: Optional[torch.LongTensor] = None,
+			past_key_value: Optional[Cache] = None,
+			output_attentions: bool = False,
+			use_cache: bool = False,
+			cache_position: Optional[torch.LongTensor] = None,
+			**kwargs,
+		) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+			bsz, q_len, _ = hidden_states.size()
+
+			query_states = self.q_proj(hidden_states)
+			key_states = self.k_proj(hidden_states)
+			value_states = self.v_proj(hidden_states)
+
+			query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+			key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+			value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+			cos, sin = self.rotary_emb(value_states, position_ids)
+			query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+
+			past_key_value = getattr(self, "past_key_value", past_key_value)
+
+			if past_key_value is not None:
+				# sin and cos are specific to RoPE models; cache_position needed for the static cache
+				cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+				key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+			query_states = query_states.transpose(1, 2)
+			key_states = key_states.transpose(1, 2)
+			value_states = value_states.transpose(1, 2)
+
+			dropout_rate = self.attention_dropout if self.training else 0.0
+
+			# copied from https://github.com/oobabooga/text-generation-webui/pull/950/files
+			if attention_mask is None or attention_mask[0, 0, 0, 1] == 0:
+				attn_output = memory_efficient_attention(query_states, key_states, value_states, attn_bias=None)
+			else:
+				attn_output = memory_efficient_attention(query_states, key_states, value_states, attn_bias=LowerTriangularMask())
+
+			attn_weights = None
+
+			attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+			attn_output = self.o_proj(attn_output)
+
+			return attn_output, attn_weights, past_key_value
+
+	LLAMA_ATTENTIONS["xformers"] = LLamaXformersAttention
+
+except Exception as e:
+	print("Error creating `LLamaXformersAttention`:", e)
+
+def replace_attention( model, impl, verbose=True ):
+	device = next(model.parameters()).device
+	dtype = next(model.parameters()).dtype
+	attentions = [k.split('.') for k, m in model.named_modules() if isinstance(m, LlamaAttention)]
+
+	if impl not in LLAMA_ATTENTIONS:
+		print(f"Attention '{impl}' is not in LLAMA_ATTENTIONS")
+		return model
+
+	klass = LLAMA_ATTENTIONS[impl]
+
+	for *parent, k in attentions:
+		name = '.'.join(parent)
+
+		# copy parameters
+		m = getattr( model.get_submodule(name), k )
+
+		if isinstance(m, klass):
+			continue
+
+		config = m.config
+		layer_idx = m.layer_idx
+
+		kwargs = dict(config=config, layer_idx=layer_idx)
+
+		# overwrite
+		setattr(
+			model.get_submodule(name), k,
+			klass( **kwargs ).to(device=device, dtype=dtype)
+		)
+
+		if verbose:
+			print(f"Replacing {name}.{k} with", klass)
+
+	return model
+
 def _create_mask(l, device):
 	"""1 is valid region and 0 is invalid."""
 	seq = torch.arange(max(l), device=device).unsqueeze(0) # (1 t)
@@ -431,6 +532,14 @@ class Base(nn.Module):
 			))
 		elif self.arch_type == "llama":
 			if n_experts <= 1:
+				# ick, there has to be a better way
+				attention = self.config.attention if self.config is not None else None # "flash_attention_2",
+				use_xformers = False
+
+				if attention == "xformers":
+					use_xformers = True
+					attention = None
+
 				self.model = LlamaModel(LlamaConfig(
 					vocab_size=n_resp_tokens,
 					hidden_size=d_model,
@@ -444,8 +553,11 @@ class Base(nn.Module):
 					hidden_act="gelu",
 					is_encoder_decoder=False,
 					is_decoder=True,
-					attn_implementation=self.config.attention if self.config is not None else None, # "flash_attention_2",
+					attn_implementation=attention,
 				))
+
+				if use_xformers:
+					self.model = replace_attention( self.model, "xformers" if use_xformers else attention )
 			else:
 				self.model = MixtralModel(MixtralConfig(
 					vocab_size =n_resp_tokens,
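
Usage note (not part of the patch): a minimal sketch of how the new `replace_attention` helper can be exercised outside of vall_e's config plumbing. The toy model sizes, the "eager" baseline attention, and the CUDA/fp16 placement below are illustrative assumptions, not values taken from vall_e; xformers' `memory_efficient_attention` itself only runs on a CUDA device.

	# sketch: build a stock LlamaModel, then swap every LlamaAttention module
	# for the xformers implementation registered by this patch.
	import torch
	from transformers import LlamaConfig, LlamaModel

	from vall_e.models.base import replace_attention  # added by this patch

	config = LlamaConfig(
		vocab_size=256,            # toy sizes for illustration only
		hidden_size=256,
		intermediate_size=1024,
		num_hidden_layers=2,
		num_attention_heads=4,
		attn_implementation="eager",  # keeps a 4D causal mask, which the xformers path checks
	)
	model = LlamaModel(config).to(device="cuda", dtype=torch.float16)
	model = replace_attention(model, "xformers")  # prints one "Replacing ..." line per layer

	tokens = torch.randint(0, 256, (1, 32), device="cuda")
	with torch.no_grad():
		out = model(tokens)
	print(out.last_hidden_state.shape)  # torch.Size([1, 32, 256])

This mirrors what the patch does internally for the "xformers" setting: the LlamaModel is first built with a stock attention implementation, then its attention modules are replaced in place and moved to the model's device and dtype.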