yuge speedup because of a dumb oversight

mrq 2025-03-20 17:39:41 -05:00
parent 8068f24e35
commit 589cfb0e18
2 changed files with 13 additions and 8 deletions

@@ -1423,7 +1423,7 @@ class Base(nn.Module):
 # calculate token probabilities
 scores = [
-	[ F.softmax(logit[i, :], dim=-1)[token].item() for i, token in enumerate(tokens) ]
+	F.softmax(logit, dim=-1).gather(1, tokens.unsqueeze(-1)).squeeze(-1)
 	for logit, tokens in zip(logits, res)
 ]
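
The removed line indexes the softmax output one step at a time and calls .item() on each probability, which stalls on a device-to-host copy for every token; the added line computes the softmax once and pulls out every sampled token's probability with a single batched gather. A minimal sketch of the equivalence, with made-up shapes (a 4-step sequence over a 10-token vocabulary, nothing taken from the repo):

import torch
import torch.nn.functional as F

logit = torch.randn(4, 10)           # [seq_len, vocab] logits for one sequence
tokens = torch.randint(0, 10, (4,))  # [seq_len] sampled token ids

# old path: per-step softmax indexing plus .item(), one sync per token
slow = [ F.softmax(logit[i, :], dim=-1)[token].item() for i, token in enumerate(tokens) ]

# new path: one softmax, one gather along the vocab dimension
fast = F.softmax(logit, dim=-1).gather(1, tokens.unsqueeze(-1)).squeeze(-1)

assert torch.allclose(torch.tensor(slow), fast)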

@@ -103,11 +103,16 @@ class FiniteAudioEncoder(nn.Module):
 self.embs = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for _ in range(n_levels)])
 self.pos_embedding = nn.Parameter(torch.randn(1, n_levels, token_dim) * 0.02)
 self.norm = nn.LayerNorm(token_dim) if use_ln else nn.Identity()
-self.proj = nn.Sequential(
-	nn.Linear(token_dim, token_dim * d_ffn),
-	nn.GELU(),
-	nn.Linear(token_dim * d_ffn, d_model),
-) if use_ffn else nn.Linear(token_dim, d_model)
+if use_ffn:
+	self.proj = nn.Sequential(
+		nn.Linear(token_dim, token_dim * d_ffn),
+		nn.GELU(),
+		nn.Linear(token_dim * d_ffn, d_model),
+	)
+elif token_dim != d_model:
+	self.proj = nn.Linear(token_dim, d_model)
+else:
+	self.proj = nn.Identity()
 self.level_weights = nn.Parameter(torch.ones(n_levels) / math.sqrt(n_levels))
@@ -122,7 +127,7 @@ class FiniteAudioEncoder(nn.Module):
 	nn.init.zeros_(self.proj[0].bias)
 	nn.init.zeros_(self.proj[2].bias)
-else:
+elif token_dim != d_model:
 	nn.init.xavier_uniform_(self.proj.weight)
 	nn.init.zeros_(self.proj.bias)
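
Taken together, these two hunks stop the encoder from paying for a projection it does not need: when use_ffn is off and token_dim already equals d_model, self.proj becomes nn.Identity() instead of an extra nn.Linear(token_dim, d_model), and the init's bare else: has to become elif token_dim != d_model: because nn.Identity() has no weight or bias to initialize. A standalone sketch of the three cases, with a hypothetical configuration:

import torch.nn as nn

token_dim, d_model, d_ffn, use_ffn = 256, 256, 4, False  # hypothetical config

if use_ffn:
	proj = nn.Sequential(
		nn.Linear(token_dim, token_dim * d_ffn),
		nn.GELU(),
		nn.Linear(token_dim * d_ffn, d_model),
	)
elif token_dim != d_model:
	proj = nn.Linear(token_dim, d_model)
else:
	proj = nn.Identity()  # dimensions already match: skip the extra matmul

# the init mirrors the same split; with the new construction, a bare else:
# would try to touch .weight / .bias that nn.Identity() does not have
if use_ffn:
	nn.init.xavier_uniform_(proj[0].weight)
	nn.init.xavier_uniform_(proj[2].weight)
	nn.init.zeros_(proj[0].bias)
	nn.init.zeros_(proj[2].bias)
elif token_dim != d_model:
	nn.init.xavier_uniform_(proj.weight)
	nn.init.zeros_(proj.bias)
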
@@ -1306,7 +1311,7 @@ class Base_V2(nn.Module):
 res = [ Categorical(logits=logit / temperature).sample() for logit in logits ]
 scores = [
-	torch.tensor([ [ prob[b, i, token].item() for i, token in enumerate(tokens[b]) ] for b in range(prob.size(0)) ], device=device)
+	torch.gather(prob, 2, tokens.unsqueeze(-1)).squeeze(-1)
 	for prob, tokens in zip(probabilities, res)
 ]
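
Same fix as in Base, but batched: the removed line loops over both the batch and the sequence in Python with one .item() per (batch, step) pair, while torch.gather pulls all of those probabilities out of the [batch, seq_len, vocab] tensor at once. A minimal sketch with made-up shapes (the device handling from the original is omitted):

import torch

prob = torch.rand(2, 3, 5).softmax(dim=-1)  # [batch, seq_len, vocab] probabilities
tokens = torch.randint(0, 5, (2, 3))        # [batch, seq_len] sampled token ids

# old path: two nested Python loops, one .item() per (batch, step) pair
slow = torch.tensor([ [ prob[b, i, token].item() for i, token in enumerate(tokens[b]) ] for b in range(prob.size(0)) ])

# new path: one gather along the vocab dimension
fast = torch.gather(prob, 2, tokens.unsqueeze(-1)).squeeze(-1)

assert torch.allclose(slow, fast)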