Huge speedup from fixing a dumb oversight: per-token `.item()` loops for token probabilities replaced with a vectorized `gather`
This commit is contained in:
parent
8068f24e35
commit
589cfb0e18
|
@ -1423,7 +1423,7 @@ class Base(nn.Module):
|
|||
|
||||
# calculate token probabilities
|
||||
scores = [
|
||||
[ F.softmax(logit[i, :], dim=-1)[token].item() for i, token in enumerate(tokens) ]
|
||||
F.softmax(logit, dim=-1).gather(1, tokens.unsqueeze(-1)).squeeze(-1)
|
||||
for logit, tokens in zip(logits, res)
|
||||
]
|
||||
|
||||
|
|
|
@ -103,11 +103,16 @@ class FiniteAudioEncoder(nn.Module):
|
|||
self.embs = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for _ in range(n_levels)])
|
||||
self.pos_embedding = nn.Parameter(torch.randn(1, n_levels, token_dim) * 0.02)
|
||||
self.norm = nn.LayerNorm(token_dim) if use_ln else nn.Identity()
|
||||
self.proj = nn.Sequential(
|
||||
nn.Linear(token_dim, token_dim * d_ffn),
|
||||
nn.GELU(),
|
||||
nn.Linear(token_dim * d_ffn, d_model),
|
||||
) if use_ffn else nn.Linear(token_dim, d_model)
|
||||
if use_ffn:
|
||||
self.proj = nn.Sequential(
|
||||
nn.Linear(token_dim, token_dim * d_ffn),
|
||||
nn.GELU(),
|
||||
nn.Linear(token_dim * d_ffn, d_model),
|
||||
)
|
||||
elif token_dim != d_model:
|
||||
self.proj = nn.Linear(token_dim, d_model)
|
||||
else:
|
||||
self.proj = nn.Identity()
|
||||
|
||||
self.level_weights = nn.Parameter(torch.ones(n_levels) / math.sqrt(n_levels))
|
||||
|
||||
|
@ -122,7 +127,7 @@ class FiniteAudioEncoder(nn.Module):
|
|||
|
||||
nn.init.zeros_(self.proj[0].bias)
|
||||
nn.init.zeros_(self.proj[2].bias)
|
||||
else:
|
||||
elif token_dim != d_model:
|
||||
nn.init.xavier_uniform_(self.proj.weight)
|
||||
nn.init.zeros_(self.proj.bias)
|
||||
|
||||
|
@ -1306,7 +1311,7 @@ class Base_V2(nn.Module):
|
|||
res = [ Categorical(logits=logit / temperature).sample() for logit in logits ]
|
||||
|
||||
scores = [
|
||||
torch.tensor([ [ prob[b, i, token].item() for i, token in enumerate(tokens[b]) ] for b in range(prob.size(0)) ], device=device)
|
||||
torch.gather(prob, 2, tokens.unsqueeze(-1)).squeeze(-1)
|
||||
for prob, tokens in zip(probabilities, res)
|
||||
]
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user