diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index b1bb6a0..3e39c61 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -372,44 +372,20 @@ class AudioDecoder(nn.Module):
 		self,
 		levels,
 		d_model,
-		config_kwargs,
+		hidden_size,
+		vocab_size,
 	):
 		super().__init__()
-		training = config_kwargs.pop("training", False)
-		attention_backend = config_kwargs.pop("attention_backend", "default")
-		gradient_checkpointing = config_kwargs.pop("gradient_checkpointing", True)
+		hidden_size *= levels
+		vocab_size *= levels
 
-		config_kwargs["hidden_size"] *= levels
-		config_kwargs["vocab_size"] *= levels
-
-		hidden_size = config_kwargs.get("hidden_size")
-		vocab_size = config_kwargs.get("vocab_size")
-
-		#self.d_model = d_model
 		self.vocab_size = vocab_size
 		self.up = nn.Linear( d_model, hidden_size )
 		self.down = nn.Linear( hidden_size, vocab_size )
-		self.transformer = None
-		"""
-		self.transformer = LlamaModel_Adapted(LlamaConfig(**config_kwargs))
-		self.transformer = ml.replace_attention( self.transformer, klass=LlamaAttention_Adapted, target=LlamaAttention, mode=attention_backend )
-
-		if hasattr( self.transformer, "embeddings" ):
-			del self.transformer.embeddings
-
-		if gradient_checkpointing and not self.transformer.gradient_checkpointing:
-			self.transformer.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
-				use_reentrant=False
-			))
-		"""
 
 	def forward(self, x: Tensor, level: int | None = None, stack: bool = True, **kwargs ) -> Tensor:
 		x = self.up( x )
-		"""
-		if self.transformer is not None:
-			x = self.transformer( inputs_embeds=x, **kwargs )["last_hidden_state"]
-		"""
 		x = self.down( x )
 
 		batch_size, seq_len, dim = x.shape
@@ -739,10 +715,7 @@ class Base(nn.Module):
 			self.resp_parallel_training = True # governs if all levels are trained in parallel or one per sample like the old way
 			self.monolithic_audio_encoder = False # monolithic sounds bad
 			if self.version >= 7:
-				pd_model = d_model // 4
-				pd_ffn = pd_model * d_ffn
-				pd_heads = n_heads // 4
-				pd_layers = 1
+				dec_dim = d_model * 4
 
 				if self.monolithic_audio_encoder:
 					self.audio_emb = AudioEncoder(
@@ -765,24 +738,8 @@ class Base(nn.Module):
 				self.audio_decoder = AudioDecoder(
 					self.n_resp_levels,
 					d_model,
-					dict(
-						vocab_size=n_audio_tokens + 1,
-						hidden_size=pd_model,
-						max_position_embeddings=max_position_embeddings,
-						intermediate_size=pd_ffn,
-						num_hidden_layers=pd_layers,
-						num_attention_heads=pd_heads,
-						attention_dropout=p_dropout if training else 0.0,
-						num_key_value_heads=pd_heads,
-						hidden_act="gelu",
-						is_encoder_decoder=False,
-						is_decoder=True,
-						attn_implementation="eager",
-
-						training=self.training,
-						attention_backend=attention_backend,
-						gradient_checkpointing=self.gradient_checkpointing,
-					)
+					dec_dim,
+					n_audio_tokens + 1,
 				)
 
 				if attention_backend == "auto":
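
For reference, a runnable sketch of what the simplified AudioDecoder amounts to after this patch: the vestigial per-level Llama sub-decoder is gone, leaving one wide up-projection and one down-projection to fused per-level logits. The hunk only shows forward() down to the shape unpacking, so the trailing view into per-level logits, the level/stack handling, and the example dimensions below are assumptions, not code from the repo.

	import torch
	from torch import nn, Tensor

	class AudioDecoderSketch(nn.Module):
		# Mirrors the patched AudioDecoder: project from the trunk width up to
		# levels * hidden_size, then down to levels * vocab_size logits.
		def __init__(self, levels: int, d_model: int, hidden_size: int, vocab_size: int):
			super().__init__()
			self.levels = levels
			hidden_size *= levels   # fuse all codebook levels into one wide hidden dim
			vocab_size *= levels    # and one fused logit dim
			self.vocab_size = vocab_size
			self.up = nn.Linear( d_model, hidden_size )
			self.down = nn.Linear( hidden_size, vocab_size )

		def forward(self, x: Tensor, level: int | None = None, stack: bool = True) -> Tensor:
			x = self.up( x )
			x = self.down( x )
			batch_size, seq_len, dim = x.shape
			# Assumed continuation: split the fused dim back into per-level logits.
			x = x.view( batch_size, seq_len, self.levels, dim // self.levels )
			if level is not None:
				return x[:, :, level, :]  # logits for a single codebook level
			return x  # (batch, seq, levels, vocab); `stack` kept for signature parity

	# Usage with the shapes Base now passes: dec_dim = d_model * 4 as hidden_size,
	# n_audio_tokens + 1 as vocab_size (the concrete values here are illustrative).
	dec = AudioDecoderSketch(levels=8, d_model=1024, hidden_size=4096, vocab_size=1025)
	logits = dec(torch.randn(2, 50, 1024))  # -> torch.Size([2, 50, 8, 1025])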