it's a miracle I was able to get a semblance of audio with the naive AudioEncoder (now it interleaves properly)
parent 6e7b269147
commit 33d5a7109a
@@ -354,36 +354,55 @@ class AudioEncoder(nn.Module):
 		self.proj = nn.Linear(8 * token_dim, 1 * token_dim)
 
 	def forward(self, xi: Tensor, dropout_mask = None, dropout_token = None ) -> Tensor:
 		# empty
 		if xi.shape[0] == 0:
 			return torch.zeros((0, self.proj.weight.shape[0]), device=xi.device)
 		if dropout_mask is not None:
 			xi = _dropout_codes( xi, dropout_mask, dropout_token )
 
-		x = torch.cat([ emb( xi[:, l] ) for l, emb in enumerate(self.embs) ], dim=-1)
-		x = self.proj(x)
 		# old way
 		"""
 		x = sum([ emb( xi[:, l] ) for l, emb in enumerate(self.embs) ])
 		"""
+		# naive way to "encode" by flattening
+		"""
+		x = torch.cat([ emb( xi[:, l] ) for l, emb in enumerate(self.embs) ], dim=-1)
+		x = self.proj(x)
+		"""
+
+		# encode by interleaving
+		seq_len = xi.shape[0]
+		# (8, seq_len, dim)
+		x = [ emb( xi[:, l] ) for l, emb in enumerate(self.embs) ]
+		# => (seq_len, dim * 8) interleaved
+		x_i = []
+		for i in range(xi.shape[0]):
+			x_i.append(torch.cat([ x[l][i] for l in range(len(self.embs)) ], dim=-1))
+		x = torch.stack( x_i, dim=0 )
+		# => (seq_len, dim)
+		x = self.proj(x)
 
 		return x
 
 # Pseudo-MoE by doing additional decoding from the main transformer's last hidden output
 # ironically, not using a classifier to hidden_dim => audio_tokens causes problems with fitment
 class AudioDecoder(nn.Module):
 	def __init__(
 		self,
 		d_model,
 		hidden_size,
 		vocab_size,
+		resp_levels,
 	):
 		super().__init__()
 
 		self.vocab_size = vocab_size
 		self.up = nn.Linear( d_model, hidden_size )
-		self.down = nn.Linear( hidden_size, vocab_size )
+		self.down = nn.Linear( hidden_size, vocab_size * resp_levels )
 
 	def forward(self, x: Tensor, level: int | None = None, stack: bool = True, **kwargs ) -> Tensor:
 		x = self.up( x )
 		x = self.down( x )
+
+		# interleave by reshaping / permuting
+		# at least I hope this does it properly
+		batch_size, seq_len, dim = x.shape
+		x = x.reshape( batch_size, seq_len, 8, dim // 8 )
+		x = x.permute( 0, 2, 1, 3 )
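
For reference, the per-timestep loop in the new forward can be checked in isolation. A minimal sketch (hypothetical sizes; embs_out stands in for the per-level embedding outputs built from self.embs) confirming that stacking the per-index concatenations matches a single torch.cat over the feature dimension, plus a stack/flatten variant for a true element-wise interleave:

import torch

# hypothetical sizes; embs_out plays the role of [ emb( xi[:, l] ) ... ] above
n_levels, seq_len, dim = 8, 4, 16
embs_out = [ torch.randn(seq_len, dim) for _ in range(n_levels) ]

# the per-timestep loop from the commit
x_i = [ torch.cat([ embs_out[l][i] for l in range(n_levels) ], dim=-1) for i in range(seq_len) ]
looped = torch.stack(x_i, dim=0)  # (seq_len, dim * n_levels)

# single-op equivalent: concatenate the level embeddings along the feature dim
assert torch.equal(looped, torch.cat(embs_out, dim=-1))

# an element-wise interleave (level-major within each feature slot) would instead be:
interleaved = torch.stack(embs_out, dim=-1).flatten(1)  # (seq_len, dim * n_levels)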
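The reshape/permute added to AudioDecoder.forward can be sanity-checked the same way. A small sketch (hypothetical sizes; the 8 mirrors the hardcoded level count) showing that it splits the projection output into eight contiguous vocab-sized blocks per position and moves the level axis to the front:

import torch

batch_size, seq_len, n_levels, vocab = 2, 3, 8, 5
x = torch.randn(batch_size, seq_len, n_levels * vocab)

# the de-interleave from the commit: (B, T, 8 * V) => (B, 8, T, V)
y = x.reshape(batch_size, seq_len, n_levels, vocab).permute(0, 2, 1, 3)

# level l's logits at every position are the l-th contiguous vocab-sized slice
for l in range(n_levels):
    assert torch.equal(y[:, l], x[..., l * vocab : (l + 1) * vocab])
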
@@ -738,7 +757,8 @@ class Base(nn.Module):
 		self.audio_decoder = AudioDecoder(
 			d_model,
 			d_model * 2,
-			(n_audio_tokens + 1) * self.n_resp_levels,
+			(n_audio_tokens + 1),
+			self.n_resp_levels,
 		)
 
 		if attention_backend == "auto":
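
With the split arguments, the call now lines up with the new signature as (d_model, hidden_size, vocab_size, resp_levels). A sketch of the resulting shape flow (hypothetical sizes, mirroring the wiring above rather than the actual module):

import torch
from torch import nn

d_model, n_audio_tokens, n_resp_levels = 1024, 1000, 8
up = nn.Linear(d_model, d_model * 2)                                 # d_model => hidden_size
down = nn.Linear(d_model * 2, (n_audio_tokens + 1) * n_resp_levels)  # hidden_size => vocab * levels

h = torch.randn(2, 50, d_model)  # (batch, seq_len, d_model) from the main transformer
logits = down(up(h))             # (2, 50, (n_audio_tokens + 1) * n_resp_levels)
logits = logits.reshape(2, 50, n_resp_levels, n_audio_tokens + 1).permute(0, 2, 1, 3)
# => (2, n_resp_levels, 50, n_audio_tokens + 1): per-level logits for every position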