add adapted MixtralAttention for when I make a bad decision to actually train a MoE

This commit is contained in:
mrq 2024-08-04 22:03:22 -05:00
parent 10aaf840e7
commit debcc93e7e
5 changed files with 207 additions and 8 deletions

View File

@ -375,7 +375,7 @@ def example_usage():
'n_text_tokens': 256,
'n_audio_tokens': 1024,
'd_model': 256, # 256, # 1024, # 1536
'd_model': 1024, # 256, # 1024, # 1536
'n_heads': 16, # 4, # 16, # 24
'n_layers': 12, # 32
'n_experts': 1 if not cfg.model else cfg.model.experts,
@ -468,8 +468,6 @@ def example_usage():
engine = Engine(model=model, optimizer=optimizer)
engines = Engines({"ar+nar": engine})
engines.setup()
print( model.state_dict().keys() )
"""
if cfg.optimizations.model_offloading:

View File

@ -44,7 +44,7 @@ except Exception as e:
pass
try:
from .mixtral import MixtralModel, MixtralConfig, load_balancing_loss_func
from .mixtral import MixtralModel, MixtralConfig, MixtralAttention, MixtralAttention_Adapted, load_balancing_loss_func
AVAILABLE_ARCHES.append("mixtral")
except Exception as e:
ERROR_ARCHES["mixtral"] = e

View File

@ -131,6 +131,8 @@ class LlamaAttention_Adapted(LlamaAttention):
is_causal=is_causal,
)
print("attention")
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(bsz, q_len, -1)

View File

@ -2,9 +2,11 @@
import torch
import torch.nn.functional as F
from typing import Literal, overload, Optional, Tuple
from transformers.cache_utils import Cache
from transformers import MixtralModel, MixtralConfig
from transformers.models.mixtral.modeling_mixtral import load_balancing_loss_func, MixtralSparseMoeBlock
from transformers.models.mixtral.modeling_mixtral import load_balancing_loss_func, MixtralSparseMoeBlock, MixtralAttention, apply_rotary_pos_emb, repeat_kv
# This is required because batch sizes > 1 throws errors
def MixtralSparseMoeBlock_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@ -42,4 +44,197 @@ def MixtralSparseMoeBlock_forward(self, hidden_states: torch.Tensor) -> torch.Te
final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
return final_hidden_states, router_logits
MixtralSparseMoeBlock.forward = MixtralSparseMoeBlock_forward
MixtralSparseMoeBlock.forward = MixtralSparseMoeBlock_forward
class MixtralAttention_Adapted(MixtralAttention):
def __init__(self, *args, **kwargs):
if 'mode' in kwargs:
self.mode = kwargs['mode']
kwargs.pop("mode")
else:
self.mode = "math"
if self.mode == "math":
self.mode = torch.nn.attention.SDPBackend.MATH
elif self.mode == "mem_efficient":
self.mode = torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION
elif self.mode == "flash":
self.mode = torch.nn.attention.SDPBackend.FLASH_ATTENTION
elif self.mode == "cudnn":
self.mode = torch.nn.attention.SDPBackend.CUDNN_ATTENTION
super().__init__(*args, **kwargs)
# Adapted from MixtralAttention.forward
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if output_attentions:
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
"""
logger.warning_once(
"MixtralModel is using MixtralSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
"""
return super().forward(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
if position_embeddings is None:
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
else:
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
causal_mask = attention_mask
if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
if query_states.device.type == "cuda" and attention_mask is not None:
query_states = query_states.contiguous()
key_states = key_states.contiguous()
value_states = value_states.contiguous()
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
# The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
is_causal = True if causal_mask is None and q_len > 1 else False
#with torch.backends.cuda.sdp_kernel(enable_flash=self.mode == "flash", enable_math=self.mode == "math", enable_mem_efficient=self.mode == "mem_efficient"):
with torch.nn.attention.sdpa_kernel(self.mode):
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=causal_mask,
dropout_p=self.attention_dropout if self.training else 0.0,
is_causal=is_causal,
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value
"""
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if output_attentions:
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
return super().forward(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
if position_embeddings is None:
cos, sin = self.rotary_emb(value_states, position_ids)
else:
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
causal_mask = attention_mask
if attention_mask is not None:
causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
if query_states.device.type == "cuda" and causal_mask is not None:
query_states = query_states.contiguous()
key_states = key_states.contiguous()
value_states = value_states.contiguous()
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
is_causal = True if causal_mask is None and q_len > 1 else False
#with torch.backends.cuda.sdp_kernel(enable_flash=self.mode == "flash", enable_math=self.mode == "math", enable_mem_efficient=self.mode == "mem_efficient"):
with torch.nn.attention.sdpa_kernel(self.mode):
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=causal_mask,
dropout_p=self.attention_dropout if self.training else 0.0,
is_causal=is_causal,
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(bsz, q_len, -1)
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value
"""

View File

@ -582,6 +582,8 @@ class Base(nn.Module):
attn_implementation=hf_attention,
#gradient_checkpointing=self.gradient_checkpointing,
))
if attention_backend in ["mem_efficient", "math", "flash", "cudnn", "auto"]:
self.model = ml.replace_attention( self.model, klass=MixtralAttention_Adapted, target=MixtralAttention, mode=attention_backend )
if self.gradient_checkpointing and not self.model.gradient_checkpointing:
self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
@ -605,6 +607,8 @@ class Base(nn.Module):
attn_implementation=hf_attention,
#gradient_checkpointing=self.gradient_checkpointing,
))
if attention_backend in ["mem_efficient", "math", "flash", "cudnn", "auto"]:
self.model = ml.replace_attention( self.model, klass=LlamaAttention_Adapted, target=LlamaAttention, mode=attention_backend )
else:
self.model = MixtralModel(MixtralConfig(
vocab_size =n_resp_tokens,
@ -625,6 +629,8 @@ class Base(nn.Module):
attn_implementation=hf_attention,
#gradient_checkpointing=self.gradient_checkpointing,
))
if attention_backend in ["mem_efficient", "math", "flash", "cudnn", "auto"]:
self.model = ml.replace_attention( self.model, klass=MixtralAttention_Adapted, target=MixtralAttention, mode=attention_backend )
if self.gradient_checkpointing and not self.model.gradient_checkpointing:
self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
@ -753,8 +759,6 @@ class Base(nn.Module):
if hasattr( self.model, "embeddings" ):
del self.model.embeddings
if attention_backend in ["mem_efficient", "math", "flash", "cudnn", "auto"]:
self.model = ml.replace_attention( self.model, klass=LlamaAttention_Adapted, target=LlamaAttention, mode=attention_backend )
if not split_classifiers:
self.classifier = nn.Linear(d_model, n_resp_tokens)