diff --git a/vall_e/models/base.py b/vall_e/models/base.py index 280b659..2453e8d 100755 --- a/vall_e/models/base.py +++ b/vall_e/models/base.py @@ -1423,7 +1423,7 @@ class Base(nn.Module): # calculate token probabilities scores = [ - [ F.softmax(logit[i, :], dim=-1)[token].item() for i, token in enumerate(tokens) ] + F.softmax(logit, dim=-1).gather(1, tokens.unsqueeze(-1)).squeeze(-1) for logit, tokens in zip(logits, res) ] diff --git a/vall_e/models/base_v2.py b/vall_e/models/base_v2.py index fea0c0b..1c5d6b5 100644 --- a/vall_e/models/base_v2.py +++ b/vall_e/models/base_v2.py @@ -103,11 +103,16 @@ class FiniteAudioEncoder(nn.Module): self.embs = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for _ in range(n_levels)]) self.pos_embedding = nn.Parameter(torch.randn(1, n_levels, token_dim) * 0.02) self.norm = nn.LayerNorm(token_dim) if use_ln else nn.Identity() - self.proj = nn.Sequential( - nn.Linear(token_dim, token_dim * d_ffn), - nn.GELU(), - nn.Linear(token_dim * d_ffn, d_model), - ) if use_ffn else nn.Linear(token_dim, d_model) + if use_ffn: + self.proj = nn.Sequential( + nn.Linear(token_dim, token_dim * d_ffn), + nn.GELU(), + nn.Linear(token_dim * d_ffn, d_model), + ) + elif token_dim != d_model: + self.proj = nn.Linear(token_dim, d_model) + else: + self.proj = nn.Identity() self.level_weights = nn.Parameter(torch.ones(n_levels) / math.sqrt(n_levels)) @@ -122,7 +127,7 @@ class FiniteAudioEncoder(nn.Module): nn.init.zeros_(self.proj[0].bias) nn.init.zeros_(self.proj[2].bias) - else: + elif token_dim != d_model: nn.init.xavier_uniform_(self.proj.weight) nn.init.zeros_(self.proj.bias) @@ -1306,7 +1311,7 @@ class Base_V2(nn.Module): res = [ Categorical(logits=logit / temperature).sample() for logit in logits ] scores = [ - torch.tensor([ [ prob[b, i, token].item() for i, token in enumerate(tokens[b]) ] for b in range(prob.size(0)) ], device=device) + torch.gather(prob, 2, tokens.unsqueeze(-1)).squeeze(-1) for prob, tokens in zip(probabilities, res) ]