final tweaks before training this meme 44khz model for the 3rd time

commit ec87308d75
parent 5cd71ef238
Author: mrq
Date: 2025-03-06 15:31:15 -06:00


@@ -153,13 +153,37 @@ class FiniteAudioEncoder(nn.Module):
 		n_tokens: int,
 		n_levels: int,
 		token_dim: int,
+		use_ln: bool = True,
+		use_ffn: bool = True,
 		training: bool = True,
 	):
 		super().__init__()
 		self.embs = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for _ in range(n_levels)])
-		self.pos_embedding = nn.Parameter(torch.randn(1, n_levels, token_dim))
-		self.proj = nn.Linear(token_dim, token_dim)
-		self.level_weights = nn.Parameter(torch.ones(n_levels))
+		self.pos_embedding = nn.Parameter(torch.randn(1, n_levels, token_dim) * 0.02)
+		self.norm = nn.LayerNorm(token_dim) if use_ln else nn.Identity()
+		self.proj = nn.Sequential(
+			nn.Linear(token_dim, token_dim * 2),
+			nn.GELU(),
+			nn.Linear(token_dim * 2, token_dim),
+			#nn.Dropout(0.1 if training else 0.0)
+		) if use_ffn else nn.Linear(token_dim, token_dim)
+
+		self.level_weights = nn.Parameter(torch.ones(n_levels) / math.sqrt(n_levels))
+
+		# explicit initialization
+		for emb in self.embs:
+			torch.nn.init.normal_(emb.weight, mean=0.0, std=0.02)
+
+		self.use_ffn = use_ffn
+		if use_ffn:
+			nn.init.xavier_uniform_(self.proj[0].weight)
+			nn.init.xavier_uniform_(self.proj[2].weight)
+			nn.init.zeros_(self.proj[0].bias)
+			nn.init.zeros_(self.proj[2].bias)
+		else:
+			nn.init.xavier_uniform_(self.proj.weight)
+			nn.init.zeros_(self.proj.bias)
+
 
 	def forward(self, xi: Tensor, dropout_mask = None, dropout_token = None ) -> Tensor:
 		# empty
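A side note on the level_weights re-initialization in this hunk: because every entry still starts out equal, dividing by math.sqrt(n_levels) leaves the initial softmax mixture uniform and only shrinks the logit scale the optimizer works with. A quick check (the n_levels value is illustrative):

import math
import torch
import torch.nn.functional as F

n_levels = 8
old = torch.ones(n_levels)                        # previous init
new = torch.ones(n_levels) / math.sqrt(n_levels)  # init after this commit

# softmax of any constant vector is uniform, so both start at 1/n_levels per level
print(F.softmax(old, dim=0))  # tensor([0.1250, 0.1250, ...])
print(F.softmax(new, dim=0))  # tensor([0.1250, 0.1250, ...])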
@@ -171,6 +195,10 @@ class FiniteAudioEncoder(nn.Module):
 		x = torch.stack([ emb(xi[:, i]) for i, emb in enumerate(self.embs) ], dim=1)
 		x = x + self.pos_embedding
-		x = self.proj( x )
+		x = self.norm(x)
+		if self.use_ffn:
+			x = x + self.proj( x )
+		else:
+			x = self.proj( x )
 		weights = F.softmax(self.level_weights, dim=0).view(1, -1, 1)
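For reference, the two hunks compose into roughly the following standalone module. This is a minimal sketch, not the repo's code: the class name, the (batch, n_levels) input shape, the sizes in the usage example, and the final softmax-weighted sum over levels (the diff cuts off right after weights is computed) are assumptions, and the dropout_mask / dropout_token handling is omitted.

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

class FiniteAudioEncoderSketch(nn.Module):
	"""Per-level embeddings -> positional embedding -> LayerNorm -> residual FFN -> weighted sum over levels."""
	def __init__(self, n_tokens: int, n_levels: int, token_dim: int, use_ln: bool = True, use_ffn: bool = True):
		super().__init__()
		self.embs = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for _ in range(n_levels)])
		self.pos_embedding = nn.Parameter(torch.randn(1, n_levels, token_dim) * 0.02)
		self.norm = nn.LayerNorm(token_dim) if use_ln else nn.Identity()
		self.proj = nn.Sequential(
			nn.Linear(token_dim, token_dim * 2),
			nn.GELU(),
			nn.Linear(token_dim * 2, token_dim),
		) if use_ffn else nn.Linear(token_dim, token_dim)
		self.level_weights = nn.Parameter(torch.ones(n_levels) / math.sqrt(n_levels))
		self.use_ffn = use_ffn

	def forward(self, xi: Tensor) -> Tensor:
		# xi: (batch, n_levels) codebook indices -> (batch, n_levels, token_dim)
		x = torch.stack([emb(xi[:, i]) for i, emb in enumerate(self.embs)], dim=1)
		x = x + self.pos_embedding
		x = self.norm(x)
		if self.use_ffn:
			x = x + self.proj(x)  # residual FFN path
		else:
			x = self.proj(x)      # plain linear projection
		weights = F.softmax(self.level_weights, dim=0).view(1, -1, 1)
		# assumed reduction: collapse the level axis with the learned mixture
		return (x * weights).sum(dim=1)

enc = FiniteAudioEncoderSketch(n_tokens=1024, n_levels=8, token_dim=512)
out = enc(torch.randint(0, 1024, (2, 8)))
print(out.shape)  # torch.Size([2, 512])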