final tweaks before training this meme 44khz model for the 3rd time
This commit is contained in:
parent 5cd71ef238
commit ec87308d75
@@ -153,13 +153,37 @@ class FiniteAudioEncoder(nn.Module):
 		n_tokens: int,
 		n_levels: int,
 		token_dim: int,
+		use_ln: bool = True,
+		use_ffn: bool = True,
+		training: bool = True,
 	):
 		super().__init__()
 		self.embs = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for _ in range(n_levels)])
-		self.pos_embedding = nn.Parameter(torch.randn(1, n_levels, token_dim))
-		self.proj = nn.Linear(token_dim, token_dim)
-		self.level_weights = nn.Parameter(torch.ones(n_levels))
+		self.pos_embedding = nn.Parameter(torch.randn(1, n_levels, token_dim) * 0.02)
+		self.norm = nn.LayerNorm(token_dim) if use_ln else nn.Identity()
+		self.proj = nn.Sequential(
+			nn.Linear(token_dim, token_dim * 2),
+			nn.GELU(),
+			nn.Linear(token_dim * 2, token_dim),
+			#nn.Dropout(0.1 if training else 0.0)
+		) if use_ffn else nn.Linear(token_dim, token_dim)
+
+		self.level_weights = nn.Parameter(torch.ones(n_levels) / math.sqrt(n_levels))
+
+		# explicit initialization
+		for emb in self.embs:
+			torch.nn.init.normal_(emb.weight, mean=0.0, std=0.02)
+
+		self.use_ffn = use_ffn
+		if use_ffn:
+			nn.init.xavier_uniform_(self.proj[0].weight)
+			nn.init.xavier_uniform_(self.proj[2].weight)
+
+			nn.init.zeros_(self.proj[0].bias)
+			nn.init.zeros_(self.proj[2].bias)
+		else:
+			nn.init.xavier_uniform_(self.proj.weight)
+			nn.init.zeros_(self.proj.bias)
 
 	def forward(self, xi: Tensor, dropout_mask = None, dropout_token = None ) -> Tensor:
 		# empty
@@ -171,6 +195,10 @@ class FiniteAudioEncoder(nn.Module):
 		x = torch.stack([ emb(xi[:, i]) for i, emb in enumerate(self.embs) ], dim=1)
 		x = x + self.pos_embedding
-		x = self.proj( x )
+		x = self.norm(x)
+		if self.use_ffn:
+			x = x + self.proj( x )
+		else:
+			x = self.proj( x )
 
 		weights = F.softmax(self.level_weights, dim=0).view(1, -1, 1)
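
For readability, below is a self-contained sketch of FiniteAudioEncoder as it should look after this commit. The constructor and initialization mirror the diff; the tail of forward() past the softmax is an assumption (a learned weighted sum over codebook levels), since the second hunk ends there, and the dropout_mask / dropout_token arguments are omitted because their handling isn't visible in this diff.

import math

import torch
import torch.nn.functional as F
from torch import Tensor, nn

class FiniteAudioEncoder(nn.Module):
	def __init__(
		self,
		n_tokens: int,
		n_levels: int,
		token_dim: int,
		use_ln: bool = True,
		use_ffn: bool = True,
		training: bool = True,  # only referenced by the dropout line left commented out in the diff
	):
		super().__init__()
		# one embedding table per codebook level
		self.embs = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for _ in range(n_levels)])
		# the 0.02 scale keeps the positional term small relative to the token embeddings
		self.pos_embedding = nn.Parameter(torch.randn(1, n_levels, token_dim) * 0.02)
		self.norm = nn.LayerNorm(token_dim) if use_ln else nn.Identity()
		self.proj = nn.Sequential(
			nn.Linear(token_dim, token_dim * 2),
			nn.GELU(),
			nn.Linear(token_dim * 2, token_dim),
		) if use_ffn else nn.Linear(token_dim, token_dim)
		self.level_weights = nn.Parameter(torch.ones(n_levels) / math.sqrt(n_levels))
		self.use_ffn = use_ffn

		# explicit initialization, as in the diff
		for emb in self.embs:
			nn.init.normal_(emb.weight, mean=0.0, std=0.02)
		if use_ffn:
			nn.init.xavier_uniform_(self.proj[0].weight)
			nn.init.xavier_uniform_(self.proj[2].weight)
			nn.init.zeros_(self.proj[0].bias)
			nn.init.zeros_(self.proj[2].bias)
		else:
			nn.init.xavier_uniform_(self.proj.weight)
			nn.init.zeros_(self.proj.bias)

	def forward(self, xi: Tensor) -> Tensor:
		# xi: (seq_len, n_levels) integer codes; column i holds the level-i token ids
		x = torch.stack([emb(xi[:, i]) for i, emb in enumerate(self.embs)], dim=1)
		x = x + self.pos_embedding
		x = self.norm(x)
		x = x + self.proj(x) if self.use_ffn else self.proj(x)
		# assumed tail: collapse the level axis with learned softmax weights
		weights = F.softmax(self.level_weights, dim=0).view(1, -1, 1)
		return (x * weights).sum(dim=1)  # (seq_len, token_dim)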
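And a quick shape check of the sketch, using hypothetical codec dimensions (the real n_tokens / n_levels / token_dim values are configured elsewhere in the repo):

enc = FiniteAudioEncoder(n_tokens=1024, n_levels=8, token_dim=1024)
codes = torch.randint(0, 1024, (500, 8))  # 500 frames, 8 codebook levels
out = enc(codes)
print(out.shape)  # torch.Size([500, 1024])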