It's a miracle I was able to get a semblance of audio with the naive AudioEncoder (now it interleaves properly)
This commit is contained in:
parent
6e7b269147
commit
33d5a7109a
|
@ -354,36 +354,55 @@ class AudioEncoder(nn.Module):
|
||||||
self.proj = nn.Linear(8 * token_dim, 1 * token_dim)
|
self.proj = nn.Linear(8 * token_dim, 1 * token_dim)
|
||||||
|
|
||||||
def forward(self, xi: Tensor, dropout_mask = None, dropout_token = None ) -> Tensor:
|
def forward(self, xi: Tensor, dropout_mask = None, dropout_token = None ) -> Tensor:
|
||||||
|
# empty
|
||||||
|
if xi.shape[0] == 0:
|
||||||
|
return torch.zeros((0, self.proj.weight.shape[0]), device=xi.device)
|
||||||
if dropout_mask is not None:
|
if dropout_mask is not None:
|
||||||
xi = _dropout_codes( xi, dropout_mask, dropout_token )
|
xi = _dropout_codes( xi, dropout_mask, dropout_token )
|
||||||
|
|
||||||
x = torch.cat([ emb( xi[:, l] ) for l, emb in enumerate(self.embs) ], dim=-1)
|
# old way
|
||||||
x = self.proj(x)
|
|
||||||
"""
|
"""
|
||||||
x = sum([ emb( xi[:, l] ) for l, emb in enumerate(self.embs) ])
|
x = sum([ emb( xi[:, l] ) for l, emb in enumerate(self.embs) ])
|
||||||
"""
|
"""
|
||||||
|
# naive way to "encode" by flattening
|
||||||
|
"""
|
||||||
|
x = torch.cat([ emb( xi[:, l] ) for l, emb in enumerate(self.embs) ], dim=-1)
|
||||||
|
x = self.proj(x)
|
||||||
|
"""
|
||||||
|
|
||||||
|
# encode by interleaving
|
||||||
|
seq_len = xi.shape[0]
|
||||||
|
# (8, seq_len, dim)
|
||||||
|
x = [ emb( xi[:, l] ) for l, emb in enumerate(self.embs) ]
|
||||||
|
# => (seq_len, dim * 8) interleaved
|
||||||
|
x_i = []
|
||||||
|
for i in range(xi.shape[0]):
|
||||||
|
x_i.append(torch.cat([ x[l][i] for l in range(len(self.embs)) ], dim=-1))
|
||||||
|
x = torch.stack( x_i, dim=0 )
|
||||||
|
# => (seq_len, dim)
|
||||||
|
x = self.proj(x)
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
# Pseudo-MoE by doing additional decoding from the main transformer's last hidden output
|
|
||||||
# ironically, not using a classifier to hidden_dim => audio_tokens causes problems with fitment
|
|
||||||
class AudioDecoder(nn.Module):
|
class AudioDecoder(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
d_model,
|
d_model,
|
||||||
hidden_size,
|
hidden_size,
|
||||||
vocab_size,
|
vocab_size,
|
||||||
|
resp_levels,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.vocab_size = vocab_size
|
|
||||||
self.up = nn.Linear( d_model, hidden_size )
|
self.up = nn.Linear( d_model, hidden_size )
|
||||||
self.down = nn.Linear( hidden_size, vocab_size )
|
self.down = nn.Linear( hidden_size, vocab_size * resp_levels )
|
||||||
|
|
||||||
def forward(self, x: Tensor, level: int | None = None, stack: bool = True, **kwargs ) -> Tensor:
|
def forward(self, x: Tensor, level: int | None = None, stack: bool = True, **kwargs ) -> Tensor:
|
||||||
x = self.up( x )
|
x = self.up( x )
|
||||||
x = self.down( x )
|
x = self.down( x )
|
||||||
|
|
||||||
|
# interleave by reshaping / permuting
|
||||||
|
# at least I hope this does it properly
|
||||||
batch_size, seq_len, dim = x.shape
|
batch_size, seq_len, dim = x.shape
|
||||||
x = x.reshape( batch_size, seq_len, 8, dim // 8 )
|
x = x.reshape( batch_size, seq_len, 8, dim // 8 )
|
||||||
x = x.permute( 0, 2, 1, 3 )
|
x = x.permute( 0, 2, 1, 3 )
|
||||||
|
@ -738,7 +757,8 @@ class Base(nn.Module):
|
||||||
self.audio_decoder = AudioDecoder(
|
self.audio_decoder = AudioDecoder(
|
||||||
d_model,
|
d_model,
|
||||||
d_model * 2,
|
d_model * 2,
|
||||||
(n_audio_tokens + 1) * self.n_resp_levels,
|
(n_audio_tokens + 1),
|
||||||
|
self.n_resp_levels,
|
||||||
)
|
)
|
||||||
|
|
||||||
if attention_backend == "auto":
|
if attention_backend == "auto":
|
||||||
|
|
Loading…
Reference in New Issue
Block a user