NOTE(review): re-flowed from a whitespace-collapsed copy of this patch. Indentation (tabs assumed)
and the position of blank context lines are reconstructed -- confirm against vall_e/models/base.py
before applying. Hunk line counts are recomputed for the amended '+' lines; the post-image blob
hash on the index line is carried over from the original patch and is stale.

diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index 16c8b20..72d7f10 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -354,36 +354,45 @@ class AudioEncoder(nn.Module):
 		self.proj = nn.Linear(8 * token_dim, 1 * token_dim)
 
 	def forward(self, xi: Tensor, dropout_mask = None, dropout_token = None ) -> Tensor:
+		# empty input: return an empty (0, out_features) tensor in the projection's dtype
+		if xi.shape[0] == 0:
+			return torch.zeros((0, self.proj.out_features), device=xi.device, dtype=self.proj.weight.dtype)
 		if dropout_mask is not None:
 			xi = _dropout_codes( xi, dropout_mask, dropout_token )
 
+		# concatenate the per-level embeddings along the feature axis => (seq_len, dim * 8)
+		# NOTE: building each row by hand (cat of x[l][i] over levels, then stack over i)
+		# produces exactly this tensor, so no per-timestep Python loop is needed
 		x = torch.cat([ emb( xi[:, l] ) for l, emb in enumerate(self.embs) ], dim=-1)
+		# => (seq_len, dim)
 		x = self.proj(x)
+		# old way
 		"""
 		x = sum([ emb( xi[:, l] ) for l, emb in enumerate(self.embs) ])
 		"""
 
 		return x
 
-# Pseudo-MoE by doing additional decoding from the main transformer's last hidden output
-# ironically, not using a classifier to hidden_dim => audio_tokens causes problems with fitment
 class AudioDecoder(nn.Module):
 	def __init__(
 		self,
 		d_model,
 		hidden_size,
 		vocab_size,
+		resp_levels,
 	):
 		super().__init__()
 
-		self.vocab_size = vocab_size
+		self.resp_levels = resp_levels
 		self.up = nn.Linear( d_model, hidden_size )
-		self.down = nn.Linear( hidden_size, vocab_size )
+		self.down = nn.Linear( hidden_size, vocab_size * resp_levels )
 
 	def forward(self, x: Tensor, level: int | None = None, stack: bool = True, **kwargs ) -> Tensor:
 		x = self.up( x )
 		x = self.down( x )
 
+		# split the last axis into one vocab-sized slice per level, then move the level axis forward:
+		# (batch, seq_len, levels * vocab) => (batch, seq_len, levels, vocab) => (batch, levels, seq_len, vocab)
 		batch_size, seq_len, dim = x.shape
-		x = x.reshape( batch_size, seq_len, 8, dim // 8 )
+		x = x.reshape( batch_size, seq_len, self.resp_levels, dim // self.resp_levels )
 		x = x.permute( 0, 2, 1, 3 )
@@ -738,7 +747,8 @@ class Base(nn.Module):
 		self.audio_decoder = AudioDecoder(
 			d_model,
 			d_model * 2,
-			(n_audio_tokens + 1) * self.n_resp_levels,
+			(n_audio_tokens + 1),
+			self.n_resp_levels,
 		)
 
 		if attention_backend == "auto":