fix attention backend not being used

mrq 2025-02-27 21:38:38 -06:00
parent b8e9f3d785
commit 0a45c9c042
3 changed files with 28 additions and 18 deletions

@@ -116,10 +116,11 @@ class RotaryEmbedding(nn.Module):
 		return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
 
 class Attention(nn.Module):
-	def __init__(self, config, layer_idx, mode = "default"):
+	def __init__(self, config, layer_idx):
 		super().__init__()
 		self.config = config
+		self.attn_mode = config.attn_mode
 		self.layer_idx = layer_idx
 		self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
 		self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
@@ -128,7 +129,6 @@ class Attention(nn.Module):
 		self.num_heads = config.num_attention_heads
 		self.num_key_value_heads = config.num_key_value_heads
-		self.attn_mode = mode
 
 		if self.attn_mode == "math":
 			self.attn_mode = torch.nn.attention.SDPBackend.MATH
 		elif self.attn_mode == "mem_efficient":
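For context, the strings compared above map onto PyTorch's scaled-dot-product-attention backends. Below is a minimal sketch, not this repo's forward pass, of how such a string-to-SDPBackend mapping is typically applied; it assumes PyTorch 2.3+ for torch.nn.attention.sdpa_kernel, and the dict keys and tensor shapes are illustrative.

# Minimal sketch (assumptions noted above): pick an SDPA backend from a config
# string and scope scaled_dot_product_attention to it.
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# illustrative mapping; the repo's own mode names may differ
BACKENDS = {
	"math": SDPBackend.MATH,
	"mem_efficient": SDPBackend.EFFICIENT_ATTENTION,
	"flash": SDPBackend.FLASH_ATTENTION,
}

def sdpa_with_backend(q, k, v, attn_mode="math", is_causal=True):
	# q/k/v: (batch, heads, seq_len, head_dim)
	with sdpa_kernel(BACKENDS.get(attn_mode, SDPBackend.MATH)):
		return F.scaled_dot_product_attention(q, k, v, is_causal=is_causal)

q = k = v = torch.randn(1, 4, 16, 64)
out = sdpa_with_backend(q, k, v, attn_mode="mem_efficient")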

@@ -456,7 +456,8 @@ class Base(nn.Module):
 		elif attention_backend == "fused_attn":
 			self.l_padding = 128
 
-		self.model = LlamaModel(LlamaConfig(
+		self.model_config = LlamaConfig(
 			vocab_size=n_vocab,
 			hidden_size=d_model,
 			max_position_embeddings=max_position_embeddings,
@@ -468,9 +469,10 @@ class Base(nn.Module):
 			hidden_act="gelu",
 			is_encoder_decoder=False,
 			is_decoder=True,
-			attn_implementation=hf_attention,
 			#gradient_checkpointing=self.gradient_checkpointing,
-		))
+		)
+		self.model_config.attn_mode = attention_backend
+		self.model = LlamaModel(self.model_config)
 
 		if not split_classifiers:
 			self.classifier = nn.Linear(d_model, n_vocab, bias=classifiers_bias)
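This hunk is the actual fix: the LlamaConfig is kept around so a custom attn_mode attribute can be attached to it, and each Attention layer now reads that attribute (first file) instead of a constructor argument that was never wired up. A small sketch of the pattern, assuming transformers' LlamaConfig; the tiny sizes and the "sdpa" value are placeholders.

# Sketch of stashing a custom field on a Hugging Face config so every layer can
# read it, rather than threading an extra constructor argument through.
from transformers import LlamaConfig

config = LlamaConfig(vocab_size=256, hidden_size=64, num_hidden_layers=2, num_attention_heads=4)
config.attn_mode = "sdpa"  # custom attribute, not a stock LlamaConfig field

class ToyAttention:  # stand-in for the repo's Attention(nn.Module)
	def __init__(self, config, layer_idx):
		self.attn_mode = config.attn_mode  # backend choice comes from the shared config
		self.layer_idx = layer_idx

assert ToyAttention(config, layer_idx=0).attn_mode == "sdpa"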

@@ -88,10 +88,20 @@ class AudioEncoder(nn.Module):
 		n_tokens: int,
 		n_levels: int,
 		token_dim: int,
+		enc_mode: str = "sum"
 	):
 		super().__init__()
-		self.embs = nn.ModuleList([ml.Embedding(n_tokens, token_dim // n_levels) for l in range(n_levels)])
-		# self.proj = nn.Linear(8 * token_dim, 1 * token_dim)
+		self.enc_mode = enc_mode
+		if enc_mode == "sum":
+			self.embs = nn.ModuleList([ml.Embedding(n_tokens, token_dim) for l in range(n_levels)])
+			self.proj = None
+		elif enc_mode == "sub_interleave":
+			self.embs = nn.ModuleList([ml.Embedding(n_tokens, token_dim // n_levels) for l in range(n_levels)])
+			self.proj = None
+		elif enc_mode == "interleave":
+			self.embs = nn.ModuleList([ml.Embedding(n_tokens, token_dim) for l in range(n_levels)])
+			self.proj = nn.Linear(8 * token_dim, 1 * token_dim)
 
 	def forward(self, xi: Tensor, dropout_mask = None, dropout_token = None ) -> Tensor:
 		# empty
@@ -103,18 +113,15 @@ class AudioEncoder(nn.Module):
 		# old way
 		# in theory RVQ-based codecs should prefer this, but this doesn't yield good results
-		"""
-		x = sum([ emb( xi[:, l] ) for l, emb in enumerate(self.embs) ])
-		"""
+		if self.enc_mode == "sum":
+			x = sum([ emb( xi[:, l] ) for l, emb in enumerate(self.embs) ])
 		# encode by interleaving embeddings into one "token"
 		# this "works" but I imagine it being excessive and doesn't seem to help the model all that much
-		x = torch.stack([emb(xi[:, l]) for l, emb in enumerate(self.embs)], dim=1)
-		x = x.view(x.shape[0], -1)
-		"""
+		else:
+			x = torch.stack([emb(xi[:, l]) for l, emb in enumerate(self.embs)], dim=1)
+			x = x.view(x.shape[0], -1)
 		if self.proj is not None:
 			x = self.proj(x)
-		"""
 
 		return x
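To make the shapes of the new enc_mode branches concrete, here is a standalone sketch under assumed sizes (n_levels=8, token_dim=1024), with plain nn.Embedding standing in for ml.Embedding; "sub_interleave" works like "interleave" but with token_dim // n_levels per level and no projection.

# Shape sketch of the "sum" vs "interleave" encodings; all sizes are assumptions.
import torch
import torch.nn as nn

n_tokens, n_levels, token_dim, seq = 1024, 8, 1024, 10
xi = torch.randint(0, n_tokens, (seq, n_levels))  # token ids, one column per RVQ level

# "sum": one token_dim-wide embedding per level, summed -> (seq, token_dim)
embs = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for _ in range(n_levels)])
x_sum = sum(emb(xi[:, l]) for l, emb in enumerate(embs))

# "interleave": stack per-level embeddings, flatten, then project back down
# (seq, n_levels, token_dim) -> (seq, n_levels * token_dim) -> (seq, token_dim)
proj = nn.Linear(n_levels * token_dim, token_dim)
x_stacked = torch.stack([emb(xi[:, l]) for l, emb in enumerate(embs)], dim=1)
x_interleave = proj(x_stacked.view(x_stacked.shape[0], -1))

print(x_sum.shape, x_interleave.shape)  # both torch.Size([10, 1024])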
@@ -324,7 +331,7 @@ class Base_V2(nn.Module):
 			self.l_padding = 128
 
 		if self.arch_type in ["llama"]:
-			self.model = LlamaModel(LlamaConfig(
+			self.model_config = LlamaConfig(
 				vocab_size=n_vocab,
 				hidden_size=d_model,
 				max_position_embeddings=max_position_embeddings,
@@ -336,9 +343,10 @@ class Base_V2(nn.Module):
 				hidden_act="gelu",
 				is_encoder_decoder=False,
 				is_decoder=True,
-				attn_implementation=hf_attention,
 				#gradient_checkpointing=self.gradient_checkpointing,
-			))
+			)
+			self.model_config.attn_mode = attention_backend
+			self.model = LlamaModel(self.model_config)
 
 			if self.gradient_checkpointing and not self.model.gradient_checkpointing:
 				self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
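The last line above is cut off by the diff viewer, so the exact kwargs this commit passes are not visible. For reference only, a common way to call transformers' gradient_checkpointing_enable is with a gradient_checkpointing_kwargs dict, for example to disable reentrant checkpointing; whether this repo passes the same options cannot be confirmed from the diff.

# Generic example of the transformers API used above; the actual kwargs in this
# commit are truncated in the diff and may differ.
from transformers import LlamaConfig, LlamaModel

model = LlamaModel(LlamaConfig(vocab_size=256, hidden_size=64, num_hidden_layers=2, num_attention_heads=4))
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(use_reentrant=False))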