From ec87308d756ae4cdd119799f4166c98f2184d81e Mon Sep 17 00:00:00 2001
From: mrq
Date: Thu, 6 Mar 2025 15:31:15 -0600
Subject: [PATCH] final tweaks before training this meme 44khz model for the 3rd time

---
 vall_e/models/base_v2.py | 38 +++++++++++++++++++++++++++++++++-----
 1 file changed, 33 insertions(+), 5 deletions(-)

diff --git a/vall_e/models/base_v2.py b/vall_e/models/base_v2.py
index 26486e3..b586623 100644
--- a/vall_e/models/base_v2.py
+++ b/vall_e/models/base_v2.py
@@ -153,13 +153,37 @@ class FiniteAudioEncoder(nn.Module):
 		n_tokens: int,
 		n_levels: int,
 		token_dim: int,
+		use_ln: bool = True,
+		use_ffn: bool = True,
 		training: bool = True,
 	):
 		super().__init__()
 		self.embs = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for _ in range(n_levels)])
-		self.pos_embedding = nn.Parameter(torch.randn(1, n_levels, token_dim))
-		self.proj = nn.Linear(token_dim, token_dim)
-		self.level_weights = nn.Parameter(torch.ones(n_levels))
+		self.pos_embedding = nn.Parameter(torch.randn(1, n_levels, token_dim) * 0.02)
+		self.norm = nn.LayerNorm(token_dim) if use_ln else nn.Identity()
+		self.proj = nn.Sequential(
+			nn.Linear(token_dim, token_dim * 2),
+			nn.GELU(),
+			nn.Linear(token_dim * 2, token_dim),
+			#nn.Dropout(0.1 if training else 0.0)
+		) if use_ffn else nn.Linear(token_dim, token_dim)
+
+		self.level_weights = nn.Parameter(torch.ones(n_levels) / math.sqrt(n_levels))
+
+		# explicit initialization
+		for emb in self.embs:
+			torch.nn.init.normal_(emb.weight, mean=0.0, std=0.02)
+
+		self.use_ffn = use_ffn
+		if use_ffn:
+			nn.init.xavier_uniform_(self.proj[0].weight)
+			nn.init.xavier_uniform_(self.proj[2].weight)
+
+			nn.init.zeros_(self.proj[0].bias)
+			nn.init.zeros_(self.proj[2].bias)
+		else:
+			nn.init.xavier_uniform_(self.proj.weight)
+			nn.init.zeros_(self.proj.bias)
 
 	def forward(self, xi: Tensor, dropout_mask = None, dropout_token = None ) -> Tensor:
 		# empty
@@ -171,8 +195,12 @@ class FiniteAudioEncoder(nn.Module):
 
 		x = torch.stack([ emb(xi[:, i]) for i, emb in enumerate(self.embs) ], dim=1)
 		x = x + self.pos_embedding
-		x = self.proj( x )
-
+		x = self.norm(x)
+		if self.use_ffn:
+			x = x + self.proj( x )
+		else:
+			x = self.proj( x )
+
 		weights = F.softmax(self.level_weights, dim=0).view(1, -1, 1)
 		x = (x * weights).sum(dim=1)
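
Notes (for review only, not part of the patch): below is a minimal standalone sketch of the revised FiniteAudioEncoder, paraphrased from the hunks above so the shapes can be sanity-checked outside the repo. The toy sizes (n_tokens=1024, n_levels=8, token_dim=512), the random test tensor, and the simplified forward() signature (no dropout_mask/dropout_token, no empty-input early return) are assumptions for illustration, as is dropping the explicit weight initialization.

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

class FiniteAudioEncoder(nn.Module):
	# paraphrase of the patched module: per-level embeddings + learned positional
	# embedding over levels, LayerNorm, residual FFN, softmax-weighted level sum
	def __init__(self, n_tokens: int, n_levels: int, token_dim: int, use_ln: bool = True, use_ffn: bool = True):
		super().__init__()
		self.embs = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for _ in range(n_levels)])
		self.pos_embedding = nn.Parameter(torch.randn(1, n_levels, token_dim) * 0.02)
		self.norm = nn.LayerNorm(token_dim) if use_ln else nn.Identity()
		self.proj = nn.Sequential(
			nn.Linear(token_dim, token_dim * 2),
			nn.GELU(),
			nn.Linear(token_dim * 2, token_dim),
		) if use_ffn else nn.Linear(token_dim, token_dim)
		self.level_weights = nn.Parameter(torch.ones(n_levels) / math.sqrt(n_levels))
		self.use_ffn = use_ffn

	def forward(self, xi: Tensor) -> Tensor:
		# xi: [seq_len, n_levels] integer codes
		x = torch.stack([emb(xi[:, i]) for i, emb in enumerate(self.embs)], dim=1)  # [seq_len, n_levels, token_dim]
		x = x + self.pos_embedding
		x = self.norm(x)
		if self.use_ffn:
			x = x + self.proj(x)  # residual FFN
		else:
			x = self.proj(x)
		weights = F.softmax(self.level_weights, dim=0).view(1, -1, 1)
		return (x * weights).sum(dim=1)  # collapse levels -> [seq_len, token_dim]

# assumed toy sizes, just to confirm the output shape
enc = FiniteAudioEncoder(n_tokens=1024, n_levels=8, token_dim=512)
codes = torch.randint(0, 1024, (75, 8))  # 75 frames, 8 codebook levels
print(enc(codes).shape)  # torch.Size([75, 512])

With use_ffn=False the sketch falls back to a single Linear projection with no residual, mirroring the patch's else branch.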