fix attention backend not being used
commit 0a45c9c042
parent b8e9f3d785
@@ -116,10 +116,11 @@ class RotaryEmbedding(nn.Module):
 		return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
 
 class Attention(nn.Module):
-	def __init__(self, config, layer_idx, mode = "default"):
+	def __init__(self, config, layer_idx):
 		super().__init__()
 
 		self.config = config
+		self.attn_mode = config.attn_mode
 		self.layer_idx = layer_idx
 		self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
 		self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
@@ -128,7 +129,6 @@ class Attention(nn.Module):
 		self.num_heads = config.num_attention_heads
 		self.num_key_value_heads = config.num_key_value_heads
 
-		self.attn_mode = mode
 		if self.attn_mode == "math":
 			self.attn_mode = torch.nn.attention.SDPBackend.MATH
 		elif self.attn_mode == "mem_efficient":
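Note: with the `mode` argument gone, every `Attention` layer now picks its backend up from `config.attn_mode` at construction time. For reference, a minimal standalone sketch of how such a string can be forced onto PyTorch's SDPA kernel selection (assumes PyTorch >= 2.3; the `_BACKENDS` table and `attend` helper are illustrative names, not code from this repo):

import torch
from torch.nn.attention import SDPBackend, sdpa_kernel

# illustrative string -> backend table; only "math" and "mem_efficient" appear in the hunk above
_BACKENDS = {
    "math": SDPBackend.MATH,
    "mem_efficient": SDPBackend.EFFICIENT_ATTENTION,
}

def attend(q, k, v, attn_mode="math"):
    # restrict scaled_dot_product_attention to the requested kernel for this call only
    with sdpa_kernel([_BACKENDS[attn_mode]]):
        return torch.nn.functional.scaled_dot_product_attention(q, k, v)

q = k = v = torch.randn(1, 4, 16, 64)  # (batch, heads, seq_len, head_dim)
out = attend(q, k, v, attn_mode="math")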
@@ -456,7 +456,8 @@ class Base(nn.Module):
 		elif attention_backend == "fused_attn":
 			self.l_padding = 128
 
-		self.model = LlamaModel(LlamaConfig(
+
+		self.model_config = LlamaConfig(
 			vocab_size=n_vocab,
 			hidden_size=d_model,
 			max_position_embeddings=max_position_embeddings,
@@ -468,9 +469,10 @@ class Base(nn.Module):
 			hidden_act="gelu",
 			is_encoder_decoder=False,
 			is_decoder=True,
 			attn_implementation=hf_attention,
 			#gradient_checkpointing=self.gradient_checkpointing,
-		))
-
+		)
+		self.model_config.attn_mode = attention_backend
+		self.model = LlamaModel(self.model_config)
 		if not split_classifiers:
 			self.classifier = nn.Linear(d_model, n_vocab, bias=classifiers_bias)
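Note: this is the actual fix. Previously the `LlamaConfig` was built inline inside the `LlamaModel(...)` call, so `attn_mode` was never attached to it and the `Attention` layers above fell back to their default. Building the config first, setting the custom attribute, and only then constructing the model makes `config.attn_mode` visible to every layer's `__init__`. A minimal sketch of the pattern (placeholder hyperparameters, not the repo's real ones; requires `transformers`):

import torch
from transformers import LlamaConfig, LlamaModel

model_config = LlamaConfig(
    vocab_size=256,          # placeholder sizes for the sketch
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=2,
    num_attention_heads=4,
)
# custom field: stock LlamaModel ignores it, but a patched Attention that reads
# config.attn_mode in __init__ only sees it if it is set before the model is built
model_config.attn_mode = "math"
model = LlamaModel(model_config)

ids = torch.randint(0, 256, (1, 8))
hidden = model(input_ids=ids).last_hidden_state  # (1, 8, 64)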
@@ -88,10 +88,20 @@ class AudioEncoder(nn.Module):
 		n_tokens: int,
 		n_levels: int,
 		token_dim: int,
+		enc_mode: str = "sum"
 	):
 		super().__init__()
-		self.embs = nn.ModuleList([ml.Embedding(n_tokens, token_dim // n_levels) for l in range(n_levels)])
-		# self.proj = nn.Linear(8 * token_dim, 1 * token_dim)
+		self.enc_mode = enc_mode
+
+		if enc_mode == "sum":
+			self.embs = nn.ModuleList([ml.Embedding(n_tokens, token_dim) for l in range(n_levels)])
+			self.proj = None
+		elif enc_mode == "sub_interleave":
+			self.embs = nn.ModuleList([ml.Embedding(n_tokens, token_dim // n_levels) for l in range(n_levels)])
+			self.proj = None
+		elif enc_mode == "interleave":
+			self.embs = nn.ModuleList([ml.Embedding(n_tokens, token_dim) for l in range(n_levels)])
+			self.proj = nn.Linear(8 * token_dim, 1 * token_dim)
 
 	def forward(self, xi: Tensor, dropout_mask = None, dropout_token = None ) -> Tensor:
 		# empty
@@ -103,18 +113,15 @@ class AudioEncoder(nn.Module):
 
 		# old way
 		# in theory RVQ-based codecs should prefer this, but this doesn't yield good results
-		"""
-		x = sum([ emb( xi[:, l] ) for l, emb in enumerate(self.embs) ])
-		"""
-
+		if self.enc_mode == "sum":
+			x = sum([ emb( xi[:, l] ) for l, emb in enumerate(self.embs) ])
 		# encode by interleaving embeddings into one "token"
 		# this "works" but I imagine it being excessive and doesn't seem to help the model all that much
-		x = torch.stack([emb(xi[:, l]) for l, emb in enumerate(self.embs)], dim=1)
-		x = x.view(x.shape[0], -1)
-		"""
+		else:
+			x = torch.stack([emb(xi[:, l]) for l, emb in enumerate(self.embs)], dim=1)
+			x = x.view(x.shape[0], -1)
 			if self.proj is not None:
 				x = self.proj(x)
-		"""
 
 		return x
 
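Note: the forward pass now branches on the same `enc_mode` selected in `__init__`. A standalone shape sketch of the three layouts (plain `nn.Embedding` stands in for this repo's `ml.Embedding`; sizes are illustrative):

import torch
import torch.nn as nn

n_tokens, n_levels, token_dim = 1024, 8, 256
xi = torch.randint(0, n_tokens, (10, n_levels))  # 10 frames, one code per RVQ level

# "sum": n_levels full-width embeddings, summed -> (10, token_dim)
embs = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for _ in range(n_levels)])
x_sum = sum(emb(xi[:, l]) for l, emb in enumerate(embs))

# "sub_interleave": narrow embeddings concatenated back up to token_dim -> (10, token_dim)
embs = nn.ModuleList([nn.Embedding(n_tokens, token_dim // n_levels) for _ in range(n_levels)])
x_sub = torch.stack([emb(xi[:, l]) for l, emb in enumerate(embs)], dim=1).view(xi.shape[0], -1)

# "interleave": full-width embeddings concatenated, then projected 8*token_dim -> token_dim
embs = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for _ in range(n_levels)])
proj = nn.Linear(8 * token_dim, 1 * token_dim)
x_int = proj(torch.stack([emb(xi[:, l]) for l, emb in enumerate(embs)], dim=1).view(xi.shape[0], -1))

print(x_sum.shape, x_sub.shape, x_int.shape)  # all torch.Size([10, 256])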
@@ -324,7 +331,7 @@ class Base_V2(nn.Module):
 			self.l_padding = 128
 
 		if self.arch_type in ["llama"]:
-			self.model = LlamaModel(LlamaConfig(
+			self.model_config = LlamaConfig(
 				vocab_size=n_vocab,
 				hidden_size=d_model,
 				max_position_embeddings=max_position_embeddings,
@@ -336,9 +343,10 @@ class Base_V2(nn.Module):
 				hidden_act="gelu",
 				is_encoder_decoder=False,
 				is_decoder=True,
 				attn_implementation=hf_attention,
 				#gradient_checkpointing=self.gradient_checkpointing,
-			))
-
+			)
+			self.model_config.attn_mode = attention_backend
+			self.model = LlamaModel(self.model_config)
 			if self.gradient_checkpointing and not self.model.gradient_checkpointing:
 				self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
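Note: the tail of this hunk is truncated by the diff context, so the actual kwargs passed to `gradient_checkpointing_enable` are not shown here. For reference, a minimal sketch of the `transformers` API being configured on that line (the `use_reentrant=False` choice is illustrative, not necessarily what this repo passes):

from transformers import LlamaConfig, LlamaModel

model = LlamaModel(LlamaConfig(
    vocab_size=256, hidden_size=64, intermediate_size=128,
    num_hidden_layers=2, num_attention_heads=4,
))

# mirror the guard in the hunk above: only enable checkpointing if it is not already on
if not model.gradient_checkpointing:
    model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
        use_reentrant=False,  # illustrative kwarg; the repo's real kwargs are outside the shown context
    ))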