# https://github.com/state-spaces/mamba
from torch.utils.checkpoint import checkpoint
from mamba_ssm.models.mixer_seq_simple import (
    MambaLMHeadModel,
    MambaConfig,
    MixerModel as MambaMixelModel,
    layer_norm_fn as MambaLayerNormFn,
    RMSNorm as MambaRMSNorm,
)

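# Patched forward for mamba_ssm's MixerModel. Compared to the upstream forward, this
# version (1) accepts precomputed `hidden_states` in place of `input_ids`, and
# (2) wraps each layer in torch.utils.checkpoint when `self.gradient_checkpointing`
# is truthy -- a flag this patch expects the surrounding code to set on the model,
# since upstream MixerModel does not define it.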
def MambaMixelModel_forward(self, input_ids=None, hidden_states=None, inference_params=None, **mixer_kwargs):
    # Embed the input tokens unless precomputed hidden states were passed in directly.
    if hidden_states is None:
        hidden_states = self.embedding(input_ids)

    residual = None
    for layer in self.layers:
        if self.gradient_checkpointing and hidden_states.requires_grad:
            # Recompute this layer's activations during backward to save memory.
            # use_reentrant=False selects the non-reentrant implementation, which
            # supports keyword arguments such as inference_params.
            hidden_states, residual = checkpoint(
                layer, hidden_states, residual,
                inference_params=inference_params, **mixer_kwargs,
                use_reentrant=False,
            )
        else:
            hidden_states, residual = layer(
                hidden_states, residual,
                inference_params=inference_params, **mixer_kwargs,
            )

    if not self.fused_add_norm:
        # Apply the final residual add and normalization as separate ops.
        residual = (hidden_states + residual) if residual is not None else hidden_states
        hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
    else:
        # Fused add + norm kernel. Set prenorm=False here since we don't need the residual.
        hidden_states = MambaLayerNormFn(
            hidden_states,
            self.norm_f.weight,
            self.norm_f.bias,
            eps=self.norm_f.eps,
            residual=residual,
            prenorm=False,
            residual_in_fp32=self.residual_in_fp32,
            is_rms_norm=isinstance(self.norm_f, MambaRMSNorm),
        )

    return hidden_states

MambaMixelModel.forward = MambaMixelModel_forward
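
# A minimal usage sketch, not part of the original file: the config values, sequence
# shape, device, and dtype below are illustrative assumptions (mamba_ssm's fused kernels
# generally expect CUDA). The `gradient_checkpointing` flag is set explicitly here
# because the patched forward only reads it; nothing in upstream mamba_ssm defines it.
if __name__ == "__main__":
    import torch

    config = MambaConfig(d_model=256, n_layer=4, vocab_size=1024)
    model = MambaLMHeadModel(config, device="cuda", dtype=torch.float32)

    # Turn on per-layer activation checkpointing in the patched forward.
    model.backbone.gradient_checkpointing = True

    input_ids = torch.randint(0, 1024, (2, 128), device="cuda")
    hidden_states = model.backbone(input_ids)  # routes through MambaMixelModel_forward
    print(hidden_states.shape)  # expected: (2, 128, 256)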