ugh..........
parent a65c8144f4
commit 285e493b12
@@ -372,44 +372,20 @@ class AudioDecoder(nn.Module):
 		self,
 		levels,
 		d_model,
-		config_kwargs,
+		hidden_size,
+		vocab_size,
 	):
 		super().__init__()
 
-		training = config_kwargs.pop("training", False)
-		attention_backend = config_kwargs.pop("attention_backend", "default")
-		gradient_checkpointing = config_kwargs.pop("gradient_checkpointing", True)
+		hidden_size *= levels
+		vocab_size *= levels
 
-		config_kwargs["hidden_size"] *= levels
-		config_kwargs["vocab_size"] *= levels
-
-		hidden_size = config_kwargs.get("hidden_size")
-		vocab_size = config_kwargs.get("vocab_size")
-
-		#self.d_model = d_model
 		self.vocab_size = vocab_size
 		self.up = nn.Linear( d_model, hidden_size )
 		self.down = nn.Linear( hidden_size, vocab_size )
-		self.transformer = None
-		"""
-		self.transformer = LlamaModel_Adapted(LlamaConfig(**config_kwargs))
-		self.transformer = ml.replace_attention( self.transformer, klass=LlamaAttention_Adapted, target=LlamaAttention, mode=attention_backend )
-
-		if hasattr( self.transformer, "embeddings" ):
-			del self.transformer.embeddings
-
-		if gradient_checkpointing and not self.transformer.gradient_checkpointing:
-			self.transformer.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
-				use_reentrant=False
-			))
-		"""
 
 	def forward(self, x: Tensor, level: int | None = None, stack: bool = True, **kwargs ) -> Tensor:
 		x = self.up( x )
-		"""
-		if self.transformer is not None:
-			x = self.transformer( inputs_embeds=x, **kwargs )["last_hidden_state"]
-		"""
 		x = self.down( x )
 
 		batch_size, seq_len, dim = x.shape
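For reference, a minimal self-contained sketch of what the decoder reduces to after this hunk: the inner Llama transformer is gone, leaving an up/down projection pair whose widths are scaled by the number of codebook levels. The class name and the final per-level view are illustrative assumptions; the diff above stops at the shape unpack.

from torch import nn, Tensor

class AudioDecoderSketch(nn.Module):
	# sketch of the new __init__/forward shown above (not the project's actual class)
	def __init__(self, levels, d_model, hidden_size, vocab_size):
		super().__init__()
		hidden_size *= levels
		vocab_size *= levels

		self.levels = levels
		self.vocab_size = vocab_size
		self.up = nn.Linear( d_model, hidden_size )
		self.down = nn.Linear( hidden_size, vocab_size )

	def forward(self, x: Tensor) -> Tensor:
		x = self.up( x )    # (B, T, d_model) -> (B, T, levels * hidden_size)
		x = self.down( x )  # -> (B, T, levels * vocab_size)
		batch_size, seq_len, dim = x.shape
		# assumed: split the widened logit dim back into one slice per level
		return x.view( batch_size, seq_len, self.levels, dim // self.levels )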
@@ -739,10 +715,7 @@ class Base(nn.Module):
 		self.resp_parallel_training = True # governs if all levels are trained in parallel or one per sample like the old way
 		self.monolithic_audio_encoder = False # monolithic sounds bad
 		if self.version >= 7:
-			pd_model = d_model // 4
-			pd_ffn = pd_model * d_ffn
-			pd_heads = n_heads // 4
-			pd_layers = 1
+			dec_dim = d_model * 4
 
 			if self.monolithic_audio_encoder:
 				self.audio_emb = AudioEncoder(
@@ -765,24 +738,8 @@ class Base(nn.Module):
 			self.audio_decoder = AudioDecoder(
 				self.n_resp_levels,
 				d_model,
-				dict(
-					vocab_size=n_audio_tokens + 1,
-					hidden_size=pd_model,
-					max_position_embeddings=max_position_embeddings,
-					intermediate_size=pd_ffn,
-					num_hidden_layers=pd_layers,
-					num_attention_heads=pd_heads,
-					attention_dropout=p_dropout if training else 0.0,
-					num_key_value_heads=pd_heads,
-					hidden_act="gelu",
-					is_encoder_decoder=False,
-					is_decoder=True,
-					attn_implementation="eager",
-
-					training=self.training,
-					attention_backend=attention_backend,
-					gradient_checkpointing=self.gradient_checkpointing,
-				)
+				dec_dim,
+				n_audio_tokens + 1,
 			)
 
 		if attention_backend == "auto":
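Read together with the first hunk, the new positional call lines up with the new constructor arguments as below; the sizes are made-up example values and AudioDecoderSketch is the illustrative class sketched after the first hunk, not the real AudioDecoder.

import torch

# example sizes, not values from this repo
d_model = 1024
n_resp_levels = 8
n_audio_tokens = 1024
dec_dim = d_model * 4                    # mirrors the earlier hunk in Base.__init__

decoder = AudioDecoderSketch(
	n_resp_levels,                       # levels      -> scales the widths below internally
	d_model,                             # d_model     -> input width of the up-projection
	dec_dim,                             # hidden_size -> per-level hidden width
	n_audio_tokens + 1,                  # vocab_size  -> per-level logit width
)

x = torch.randn( 2, 100, d_model )       # (batch, seq_len, d_model)
logits = decoder( x )                     # -> (2, 100, 8, 1025) under the assumed per-level view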