It's a miracle I was able to get a semblance of audio with the naive AudioEncoder (now it interleaves properly)

This commit is contained in:
mrq 2025-02-24 14:39:12 -06:00
parent 6e7b269147
commit 33d5a7109a

View File

@ -354,36 +354,55 @@ class AudioEncoder(nn.Module):
self.proj = nn.Linear(8 * token_dim, 1 * token_dim)
def forward(self, xi: Tensor, dropout_mask = None, dropout_token = None ) -> Tensor:
    """Embed every code level of `xi` and project the result down to one vector per row.

    Each column `xi[:, l]` is looked up in `self.embs[l]`; the per-level
    embeddings of a row are laid side by side (level 0 first) and the
    resulting wide row is passed through `self.proj`.

    Args:
        xi: integer code tensor indexed as `xi[:, l]`, one column per entry
            of `self.embs` (presumably `(seq_len, n_levels)` -- confirm
            against callers).
        dropout_mask: when not None, `xi` is first run through
            `_dropout_codes` with this mask.
        dropout_token: forwarded unchanged to `_dropout_codes`.

    Returns:
        `self.proj` applied to the concatenated level embeddings. For an
        input with zero rows, a `(0, out_features)` tensor of zeros on
        `xi`'s device is returned instead.
    """
    # empty
    if xi.shape[0] == 0:
        return torch.zeros((0, self.proj.weight.shape[0]), device=xi.device)

    if dropout_mask is not None:
        xi = _dropout_codes( xi, dropout_mask, dropout_token )

    # Earlier revisions summed the per-level embeddings instead of
    # concatenating them; the concatenated layout is what `self.proj` expects.

    # one embedding tensor per level, each indexed by row
    x = [ emb( xi[:, l] ) for l, emb in enumerate(self.embs) ]

    # Row i becomes [ emb_0(xi[i, 0]) | emb_1(xi[i, 1]) | ... ].
    # Concatenating along the last dim yields exactly the rows that a
    # per-timestep cat followed by a stack would build, without a Python
    # loop over the sequence.
    x = torch.cat( x, dim=-1 )

    # project the wide rows back down with the learned linear layer
    x = self.proj(x)

    return x
# Pseudo-MoE by doing additional decoding from the main transformer's last hidden output
# ironically, not using a classifier to hidden_dim => audio_tokens causes problems with fitment
class AudioDecoder(nn.Module):
def __init__(
self,
d_model,
hidden_size,
vocab_size,
resp_levels,
):
super().__init__()
self.vocab_size = vocab_size
self.up = nn.Linear( d_model, hidden_size )
self.down = nn.Linear( hidden_size, vocab_size )
self.down = nn.Linear( hidden_size, vocab_size * resp_levels )
def forward(self, x: Tensor, level: int | None = None, stack: bool = True, **kwargs ) -> Tensor:
x = self.up( x )
x = self.down( x )
# interleave by reshaping / permuting
# at least I hope this does it properly
batch_size, seq_len, dim = x.shape
x = x.reshape( batch_size, seq_len, 8, dim // 8 )
x = x.permute( 0, 2, 1, 3 )
@ -738,7 +757,8 @@ class Base(nn.Module):
self.audio_decoder = AudioDecoder(
d_model,
d_model * 2,
(n_audio_tokens + 1) * self.n_resp_levels,
(n_audio_tokens + 1),
self.n_resp_levels,
)
if attention_backend == "auto":