ugh..........
This commit is contained in:
parent a65c8144f4
commit 285e493b12
@@ -372,44 +372,20 @@ class AudioDecoder(nn.Module):
        self,
        levels,
        d_model,
        config_kwargs,
        hidden_size,
        vocab_size,
    ):
        super().__init__()

        training = config_kwargs.pop("training", False)
        attention_backend = config_kwargs.pop("attention_backend", "default")
        gradient_checkpointing = config_kwargs.pop("gradient_checkpointing", True)
        hidden_size *= levels
        vocab_size *= levels

        config_kwargs["hidden_size"] *= levels
        config_kwargs["vocab_size"] *= levels

        hidden_size = config_kwargs.get("hidden_size")
        vocab_size = config_kwargs.get("vocab_size")

        #self.d_model = d_model
        self.vocab_size = vocab_size
        self.up = nn.Linear( d_model, hidden_size )
        self.down = nn.Linear( hidden_size, vocab_size )
        self.transformer = None
        """
        self.transformer = LlamaModel_Adapted(LlamaConfig(**config_kwargs))
        self.transformer = ml.replace_attention( self.transformer, klass=LlamaAttention_Adapted, target=LlamaAttention, mode=attention_backend )

        if hasattr( self.transformer, "embeddings" ):
            del self.transformer.embeddings

        if gradient_checkpointing and not self.transformer.gradient_checkpointing:
            self.transformer.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
                use_reentrant=False
            ))
        """

    def forward(self, x: Tensor, level: int | None = None, stack: bool = True, **kwargs ) -> Tensor:
        x = self.up( x )
        """
        if self.transformer is not None:
            x = self.transformer( inputs_embeds=x, **kwargs )["last_hidden_state"]
        """
        x = self.down( x )

        batch_size, seq_len, dim = x.shape
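Note (not part of the diff): the hunk above appears to strip the small per-level transformer out of AudioDecoder and leave only the up/down linear pair, with hidden_size and vocab_size pre-multiplied by the number of levels. A minimal, self-contained sketch of such a level-stacked decoder follows; the class name, the trailing reshape, and the example sizes are assumptions for illustration, not the repo's exact code.

# Hypothetical sketch of a level-stacked linear decoder, mirroring the diff's
# "hidden_size *= levels; vocab_size *= levels" sizing. Not the repo's code.
import torch
from torch import nn, Tensor

class StackedLinearDecoder(nn.Module):
    def __init__(self, levels: int, d_model: int, hidden_size: int, vocab_size: int):
        super().__init__()
        self.levels = levels
        self.vocab_size = vocab_size
        # one wide projection pair instead of one small decoder per level
        self.up = nn.Linear(d_model, hidden_size * levels)
        self.down = nn.Linear(hidden_size * levels, vocab_size * levels)

    def forward(self, x: Tensor) -> Tensor:
        x = self.down(self.up(x))                   # [B, T, levels * vocab_size]
        batch_size, seq_len, _ = x.shape
        # split the stacked last dimension into one logit head per level
        return x.view(batch_size, seq_len, self.levels, self.vocab_size)

# usage with made-up sizes: 8 RVQ levels, 1024-dim model, 1025-token codebook
decoder = StackedLinearDecoder(levels=8, d_model=1024, hidden_size=256, vocab_size=1025)
logits = decoder(torch.randn(2, 50, 1024))          # -> [2, 50, 8, 1025]

The final view is where the `batch_size, seq_len, dim = x.shape` unpacking in the diff seems to be heading: recovering per-level logits from the stacked projection.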
@@ -739,10 +715,7 @@ class Base(nn.Module):
        self.resp_parallel_training = True # governs if all levels are trained in parallel or one per sample like the old way
        self.monolithic_audio_encoder = False # monolithic sounds bad
        if self.version >= 7:
            pd_model = d_model // 4
            pd_ffn = pd_model * d_ffn
            pd_heads = n_heads // 4
            pd_layers = 1
            dec_dim = d_model * 4

            if self.monolithic_audio_encoder:
                self.audio_emb = AudioEncoder(
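Note (not part of the diff): a worked example of the sizing above; the starting values are assumptions chosen for illustration, only the right-hand-side formulas come from the hunk.

# Assumed example values; only the formulas are taken from the diff.
d_model, d_ffn, n_heads = 1024, 4, 16
pd_model  = d_model // 4      # 256  -> width of the small per-level decoder
pd_ffn    = pd_model * d_ffn  # 1024 -> its feed-forward width
pd_heads  = n_heads // 4      # 4    -> its attention heads
pd_layers = 1                 #         a single decoder layer
dec_dim   = d_model * 4       # 4096 -> width handed to AudioDecoder below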
@@ -765,24 +738,8 @@ class Base(nn.Module):
            self.audio_decoder = AudioDecoder(
                self.n_resp_levels,
                d_model,
                dict(
                    vocab_size=n_audio_tokens + 1,
                    hidden_size=pd_model,
                    max_position_embeddings=max_position_embeddings,
                    intermediate_size=pd_ffn,
                    num_hidden_layers=pd_layers,
                    num_attention_heads=pd_heads,
                    attention_dropout=p_dropout if training else 0.0,
                    num_key_value_heads=pd_heads,
                    hidden_act="gelu",
                    is_encoder_decoder=False,
                    is_decoder=True,
                    attn_implementation="eager",

                    training=self.training,
                    attention_backend=attention_backend,
                    gradient_checkpointing=self.gradient_checkpointing,
                )
                dec_dim,
                n_audio_tokens + 1,
            )

            if attention_backend == "auto":
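Note (not part of the diff): the extracted page drops the add/remove markers, so the final call has to be inferred from the mixed lines above. The old call handed AudioDecoder a LlamaConfig-style kwargs dict for a small per-level transformer; the new call seems to pass only the sizes the two linear layers need. A hedged, self-contained sketch tying the hunks together with the hypothetical StackedLinearDecoder from the first note; every concrete number is an example value and the argument mapping is an assumption.

# Inferred shape of the new call site, expressed against the hypothetical
# StackedLinearDecoder above; not confirmed by the source.
import torch

d_model, n_resp_levels, n_audio_tokens = 1024, 8, 1024
dec_dim = d_model * 4                        # from the previous hunk

audio_decoder = StackedLinearDecoder(
    levels=n_resp_levels,                    # self.n_resp_levels in the diff
    d_model=d_model,
    hidden_size=dec_dim,                     # assumed to fill the new hidden_size arg
    vocab_size=n_audio_tokens + 1,           # codebook size plus one special token
)
logits = audio_decoder(torch.randn(2, 50, d_model))   # -> [2, 50, 8, 1025]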