ugh..........

mrq 2025-02-14 16:24:34 -06:00
parent a65c8144f4
commit 285e493b12


@@ -372,44 +372,20 @@ class AudioDecoder(nn.Module):
 		self,
 		levels,
 		d_model,
-		config_kwargs,
+		hidden_size,
+		vocab_size,
 	):
 		super().__init__()
-		training = config_kwargs.pop("training", False)
-		attention_backend = config_kwargs.pop("attention_backend", "default")
-		gradient_checkpointing = config_kwargs.pop("gradient_checkpointing", True)
+		hidden_size *= levels
+		vocab_size *= levels
-		config_kwargs["hidden_size"] *= levels
-		config_kwargs["vocab_size"] *= levels
-		hidden_size = config_kwargs.get("hidden_size")
-		vocab_size = config_kwargs.get("vocab_size")

 		#self.d_model = d_model
 		self.vocab_size = vocab_size

 		self.up = nn.Linear( d_model, hidden_size )
 		self.down = nn.Linear( hidden_size, vocab_size )

-		self.transformer = None
-		"""
-		self.transformer = LlamaModel_Adapted(LlamaConfig(**config_kwargs))
-		self.transformer = ml.replace_attention( self.transformer, klass=LlamaAttention_Adapted, target=LlamaAttention, mode=attention_backend )
-		if hasattr( self.transformer, "embeddings" ):
-			del self.transformer.embeddings
-		if gradient_checkpointing and not self.transformer.gradient_checkpointing:
-			self.transformer.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
-				use_reentrant=False
-			))
-		"""

 	def forward(self, x: Tensor, level: int | None = None, stack: bool = True, **kwargs ) -> Tensor:
 		x = self.up( x )
-		"""
-		if self.transformer is not None:
-			x = self.transformer( inputs_embeds=x, **kwargs )["last_hidden_state"]
-		"""
 		x = self.down( x )

 		batch_size, seq_len, dim = x.shape
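
With the per-level transformer gutted, the decoder in this revision reduces to a pair of linear projections with every RVQ level folded into the feature dimension. A minimal standalone sketch of what the new code path amounts to; the final view() split is an assumption inferred from the batch_size, seq_len, dim unpacking and the stack flag, since the rest of forward() sits outside this hunk:

import torch
from torch import nn, Tensor

class SimplifiedAudioDecoder(nn.Module):
	"""Hypothetical standalone equivalent of the post-commit AudioDecoder."""
	def __init__(self, levels: int, d_model: int, hidden_size: int, vocab_size: int):
		super().__init__()
		hidden_size *= levels  # one fused projection covers every RVQ level
		vocab_size *= levels
		self.levels = levels
		self.up = nn.Linear(d_model, hidden_size)
		self.down = nn.Linear(hidden_size, vocab_size)

	def forward(self, x: Tensor) -> Tensor:
		x = self.up(x)    # (B, T, d_model)         -> (B, T, hidden * levels)
		x = self.down(x)  # (B, T, hidden * levels) -> (B, T, vocab * levels)
		batch_size, seq_len, dim = x.shape
		# assumed: unfuse the logits back into one slice per level
		return x.view(batch_size, seq_len, self.levels, -1)

# e.g. 8 RVQ levels, d_model=1024, dec_dim=4096, 1024 tokens + 1 stop token
logits = SimplifiedAudioDecoder(8, 1024, 4096, 1025)(torch.randn(2, 50, 1024))
print(logits.shape)  # torch.Size([2, 50, 8, 1025])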
@@ -739,10 +715,7 @@ class Base(nn.Module):
 		self.resp_parallel_training = True # governs if all levels are trained in parallel or one per sample like the old way
 		self.monolithic_audio_encoder = False # monolithic sounds bad

 		if self.version >= 7:
-			pd_model = d_model // 4
-			pd_ffn = pd_model * d_ffn
-			pd_heads = n_heads // 4
-			pd_layers = 1
+			dec_dim = d_model * 4
 			if self.monolithic_audio_encoder:
 				self.audio_emb = AudioEncoder(
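
For a sense of scale, the arithmetic the removed lines performed, under assumed values (d_model=1024, n_heads=16, and d_ffn=4 as a width multiplier; none of these appear in the diff):

d_model, n_heads, d_ffn = 1024, 16, 4  # assumed, not stated in this diff

# old: per-level decoder transformer, a quarter of the main model's width
pd_model  = d_model // 4      # 256
pd_ffn    = pd_model * d_ffn  # 1024
pd_heads  = n_heads // 4      # 4
pd_layers = 1

# new: a single projection width for the linear-only decoder
dec_dim = d_model * 4         # 4096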
@@ -765,24 +738,8 @@ class Base(nn.Module):
 			self.audio_decoder = AudioDecoder(
 				self.n_resp_levels,
 				d_model,
-				dict(
-					vocab_size=n_audio_tokens + 1,
-					hidden_size=pd_model,
-					max_position_embeddings=max_position_embeddings,
-					intermediate_size=pd_ffn,
-					num_hidden_layers=pd_layers,
-					num_attention_heads=pd_heads,
-					attention_dropout=p_dropout if training else 0.0,
-					num_key_value_heads=pd_heads,
-					hidden_act="gelu",
-					is_encoder_decoder=False,
-					is_decoder=True,
-					attn_implementation="eager",
-					training=self.training,
-					attention_backend=attention_backend,
-					gradient_checkpointing=self.gradient_checkpointing,
-				)
+				dec_dim,
+				n_audio_tokens + 1,
 			)

 			if attention_backend == "auto":
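
After the change, AudioDecoder takes plain positional dimensions instead of a LlamaConfig kwargs dict. A hedged sketch of the new call and the layer shapes it implies, again assuming d_model=1024, 8 response levels, and 1024 audio tokens (the +1 presumably a stop token):

d_model, n_resp_levels, n_audio_tokens = 1024, 8, 1024  # assumed values
dec_dim = d_model * 4  # 4096, as set earlier in this commit

audio_decoder = AudioDecoder(
	n_resp_levels,       # levels
	d_model,             # input width
	dec_dim,             # hidden_size, scaled by levels inside __init__
	n_audio_tokens + 1,  # per-level vocab
)
# implied layers:
#   up:   Linear(1024, 4096 * 8) = Linear(1024, 32768)
#   down: Linear(32768, 1025 * 8) = Linear(32768, 8200)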