add adapted MixtralAttention for when I make a bad decision to actually train a MoE

mrq 2024-08-04 22:02:59 -05:00
parent dc9966b0fd
commit c944938d27


@@ -582,6 +582,8 @@ class Base(nn.Module):
             attn_implementation=hf_attention,
             #gradient_checkpointing=self.gradient_checkpointing,
         ))
+        if attention_backend in ["mem_efficient", "math", "flash", "cudnn", "auto"]:
+            self.model = ml.replace_attention( self.model, klass=MixtralAttention_Adapted, target=MixtralAttention, mode=attention_backend )
         if self.gradient_checkpointing and not self.model.gradient_checkpointing:
             self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
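
For reference, a minimal sketch of what a helper like ml.replace_attention could look like, assuming it walks the model's submodules, swaps every MixtralAttention instance for the adapted subclass, copies the trained weights over, and records the requested backend. The bodies of replace_attention and MixtralAttention_Adapted below are assumptions for illustration; only the call site in the diff above comes from this commit.

    # Hypothetical sketch -- the real ml.replace_attention / MixtralAttention_Adapted
    # in this repo may be implemented differently.
    import torch.nn as nn
    from transformers.models.mixtral.modeling_mixtral import MixtralAttention

    class MixtralAttention_Adapted(MixtralAttention):
        """MixtralAttention subclass that records which SDPA backend to use."""
        def __init__(self, *args, mode="auto", **kwargs):
            super().__init__(*args, **kwargs)
            self.mode = mode  # "math", "flash", "mem_efficient", "cudnn", or "auto"

    def replace_attention(model: nn.Module, klass, target, mode="auto", verbose=False):
        """Recursively swap every `target` submodule of `model` for `klass`,
        rebuilding it from the original config and copying the trained weights."""
        for name, module in model.named_children():
            if isinstance(module, target):
                # build the adapted module from the original's config/layer index
                replacement = klass(module.config, layer_idx=module.layer_idx, mode=mode)
                # carry over the trained weights, then match device/dtype
                replacement.load_state_dict(module.state_dict(), strict=False)
                p = next(module.parameters())
                replacement = replacement.to(device=p.device, dtype=p.dtype)
                setattr(model, name, replacement)
                if verbose:
                    print(f"{name}: {target.__name__} -> {klass.__name__}")
            else:
                replace_attention(module, klass, target, mode=mode, verbose=verbose)
        return model

In the adapted forward pass, the stored mode would presumably be mapped onto torch.nn.functional.scaled_dot_product_attention via the torch.nn.attention.sdpa_kernel context manager, which would explain why only the SDPA-style backend names ("mem_efficient", "math", "flash", "cudnn", "auto") trigger the replacement.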