From 0a45c9c04285671682da7e34af16df098d32b4f8 Mon Sep 17 00:00:00 2001
From: mrq
Date: Thu, 27 Feb 2025 21:38:38 -0600
Subject: [PATCH] fix attention backend not being used

---
 vall_e/models/arch/llama.py |  4 ++--
 vall_e/models/base.py       |  8 +++++---
 vall_e/models/base_v2.py    | 34 +++++++++++++++++++++-------------
 3 files changed, 28 insertions(+), 18 deletions(-)

diff --git a/vall_e/models/arch/llama.py b/vall_e/models/arch/llama.py
index 95d9617..f0ad52d 100644
--- a/vall_e/models/arch/llama.py
+++ b/vall_e/models/arch/llama.py
@@ -116,10 +116,11 @@ class RotaryEmbedding(nn.Module):
 		return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
 
 class Attention(nn.Module):
-	def __init__(self, config, layer_idx, mode = "default"):
+	def __init__(self, config, layer_idx):
 		super().__init__()
 		self.config = config
+		self.attn_mode = config.attn_mode
 		self.layer_idx = layer_idx
 		self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
 		self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
@@ -128,7 +129,6 @@ class Attention(nn.Module):
 		self.num_heads = config.num_attention_heads
 		self.num_key_value_heads = config.num_key_value_heads
 
-		self.attn_mode = mode
 		if self.attn_mode == "math":
 			self.attn_mode = torch.nn.attention.SDPBackend.MATH
 		elif self.attn_mode == "mem_efficient":
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index b81a6c9..60c0623 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -456,7 +456,8 @@ class Base(nn.Module):
 		elif attention_backend == "fused_attn":
 			self.l_padding = 128
 
-		self.model = LlamaModel(LlamaConfig(
+
+		self.model_config = LlamaConfig(
 			vocab_size=n_vocab,
 			hidden_size=d_model,
 			max_position_embeddings=max_position_embeddings,
@@ -468,9 +469,10 @@ class Base(nn.Module):
 			hidden_act="gelu",
 			is_encoder_decoder=False,
 			is_decoder=True,
-			attn_implementation=hf_attention,
 			#gradient_checkpointing=self.gradient_checkpointing,
-		))
+		)
+		self.model_config.attn_mode = attention_backend
+		self.model = LlamaModel(self.model_config)
 
 		if not split_classifiers:
 			self.classifier = nn.Linear(d_model, n_vocab, bias=classifiers_bias)
diff --git a/vall_e/models/base_v2.py b/vall_e/models/base_v2.py
index a12d93e..28e52c8 100644
--- a/vall_e/models/base_v2.py
+++ b/vall_e/models/base_v2.py
@@ -88,10 +88,20 @@ class AudioEncoder(nn.Module):
 		n_tokens: int,
 		n_levels: int,
 		token_dim: int,
+		enc_mode: str = "sum"
 	):
 		super().__init__()
-		self.embs = nn.ModuleList([ml.Embedding(n_tokens, token_dim // n_levels) for l in range(n_levels)])
-		# self.proj = nn.Linear(8 * token_dim, 1 * token_dim)
+		self.enc_mode = enc_mode
+
+		if enc_mode == "sum":
+			self.embs = nn.ModuleList([ml.Embedding(n_tokens, token_dim) for l in range(n_levels)])
+			self.proj = None
+		elif enc_mode == "sub_interleave":
+			self.embs = nn.ModuleList([ml.Embedding(n_tokens, token_dim // n_levels) for l in range(n_levels)])
+			self.proj = None
+		elif enc_mode == "interleave":
+			self.embs = nn.ModuleList([ml.Embedding(n_tokens, token_dim) for l in range(n_levels)])
+			self.proj = nn.Linear(8 * token_dim, 1 * token_dim)
 
 	def forward(self, xi: Tensor, dropout_mask = None, dropout_token = None ) -> Tensor:
 		# empty
@@ -103,18 +113,15 @@ class AudioEncoder(nn.Module):
 
 		# old way
 		# in theory RVQ-based codecs should prefer this, but this doesn't yield good results
-		"""
-		x = sum([ emb( xi[:, l] ) for l, emb in enumerate(self.embs) ])
-		"""
-
+		if self.enc_mode == "sum":
+			x = sum([ emb( xi[:, l] ) for l, emb in enumerate(self.embs) ])
 		# encode by interleaving embeddings into one "token"
 		# this "works" but I imagine it being excessive and doesn't seem to help the model all that much
-		x = torch.stack([emb(xi[:, l]) for l, emb in enumerate(self.embs)], dim=1)
-		x = x.view(x.shape[0], -1)
-		"""
+		else:
+			x = torch.stack([emb(xi[:, l]) for l, emb in enumerate(self.embs)], dim=1)
+			x = x.view(x.shape[0], -1)
 		if self.proj is not None:
 			x = self.proj(x)
-		"""
 
 		return x
@@ -324,7 +331,7 @@ class Base_V2(nn.Module):
 			self.l_padding = 128
 
 		if self.arch_type in ["llama"]:
-			self.model = LlamaModel(LlamaConfig(
+			self.model_config = LlamaConfig(
 				vocab_size=n_vocab,
 				hidden_size=d_model,
 				max_position_embeddings=max_position_embeddings,
@@ -336,9 +343,10 @@ class Base_V2(nn.Module):
 				hidden_act="gelu",
 				is_encoder_decoder=False,
 				is_decoder=True,
-				attn_implementation=hf_attention,
 				#gradient_checkpointing=self.gradient_checkpointing,
-			))
+			)
+			self.model_config.attn_mode = attention_backend
+			self.model = LlamaModel(self.model_config)
 
 			if self.gradient_checkpointing and not self.model.gradient_checkpointing:
 				self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
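
Note (not part of the patch): the pattern the change adopts is to hang the backend choice off the HF config object, so every Attention layer reads the same setting instead of receiving a per-layer constructor argument that was previously left at its "default" value. A minimal standalone sketch of that idea, with illustrative config values and a stand-in Attention class rather than the project's:

	# minimal sketch (not from the patch): carry the attention backend on the config
	import torch
	from torch.nn.attention import SDPBackend
	from transformers import LlamaConfig

	config = LlamaConfig(
		vocab_size=256,            # illustrative values, not the project's defaults
		hidden_size=512,
		num_attention_heads=8,
		num_key_value_heads=8,
	)
	# attn_mode is a custom attribute attached after construction, as in the patch;
	# PretrainedConfig objects keep arbitrary extra attributes
	config.attn_mode = "math"

	class Attention(torch.nn.Module):   # stand-in for the project's class
		def __init__(self, config, layer_idx):
			super().__init__()
			self.layer_idx = layer_idx
			self.attn_mode = config.attn_mode   # read from config, no hardcoded fallback
			# map the string name to an SDPA backend, mirroring the patched branches
			if self.attn_mode == "math":
				self.attn_mode = SDPBackend.MATH
			elif self.attn_mode == "mem_efficient":
				self.attn_mode = SDPBackend.EFFICIENT_ATTENTION

	layer = Attention(config, layer_idx=0)
	print(layer.attn_mode)   # SDPBackend.MATH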
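
Likewise not part of the patch: the three enc_mode variants added to AudioEncoder differ only in how the per-level codebook embeddings are combined into one token-sized vector. A rough shape sketch under assumed sizes, using plain nn.Embedding instead of the project's ml.Embedding, with n_levels, token_dim, and the batch of codes made up for illustration:

	# illustrative sketch of the three enc_mode variants (not the project's code)
	import torch
	import torch.nn as nn

	n_tokens, n_levels, token_dim = 1024, 8, 512
	xi = torch.randint(0, n_tokens, (10, n_levels))   # 10 frames, 8 codebook levels

	# "sum": one full-width embedding per level, summed -> (10, token_dim)
	embs = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for _ in range(n_levels)])
	x_sum = sum(emb(xi[:, l]) for l, emb in enumerate(embs))

	# "sub_interleave": each level gets token_dim // n_levels dims, concatenated -> (10, token_dim)
	sub = nn.ModuleList([nn.Embedding(n_tokens, token_dim // n_levels) for _ in range(n_levels)])
	x_sub = torch.stack([emb(xi[:, l]) for l, emb in enumerate(sub)], dim=1).view(xi.shape[0], -1)

	# "interleave": full-width embeddings concatenated -> (10, n_levels * token_dim),
	# then projected back down to token_dim (the patch hardcodes 8 levels in its Linear)
	full = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for _ in range(n_levels)])
	proj = nn.Linear(n_levels * token_dim, token_dim)
	x_int = proj(torch.stack([emb(xi[:, l]) for l, emb in enumerate(full)], dim=1).view(xi.shape[0], -1))

	print(x_sum.shape, x_sub.shape, x_int.shape)   # all (10, 512)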