maybe this will work

mrq 2025-02-27 20:42:12 -06:00
parent 01e96bafc9
commit b8e9f3d785
2 changed files with 21 additions and 18 deletions

File 1 of 2:

@@ -845,15 +845,13 @@ def example_usage():
     kwargs = {
         'n_audio_tokens': cfg.model.audio_tokens,
-        'd_model': 1024, # 256, # 1024, # 1536
-        'n_heads': 16, # 4, # 16, # 24
-        'n_layers': 12, # 32
-        'n_experts': 1 if not cfg.model else cfg.model.experts,
+        'd_model': cfg.model.dim,
+        'd_ffn': cfg.model.ffn,
+        'n_heads': cfg.model.heads,
+        'n_layers': cfg.model.layers,
+        'n_experts': cfg.model.experts,
         'p_dropout': 0.1,
         'l_padding': 8 if cfg.optimizations.fp8 else 0,
         'config': cfg.model
     }
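
For reference, a minimal sketch of what the new kwargs assume cfg.model provides. The attribute names are taken from the diff; the dataclass itself is hypothetical and the values are placeholders (reusing the old hard-coded defaults where the removed lines show them):

    from dataclasses import dataclass

    @dataclass
    class ModelConfig:                # hypothetical stand-in for cfg.model
        audio_tokens: int = 1024      # -> n_audio_tokens (placeholder value)
        dim: int = 1024               # -> d_model  (old hard-coded default)
        ffn: int = 4096               # -> d_ffn    (placeholder value)
        heads: int = 16               # -> n_heads  (old hard-coded default)
        layers: int = 12              # -> n_layers (old hard-coded default)
        experts: int = 1              # -> n_experts (old hard-coded default)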

File 2 of 2:

@@ -90,10 +90,10 @@ class AudioEncoder(nn.Module):
         token_dim: int,
     ):
         super().__init__()
-        self.embs = nn.ModuleList([ml.Embedding(n_tokens, token_dim) for l in range(n_levels)])
-        self.proj = nn.Linear(8 * token_dim, 1 * token_dim)
+        self.embs = nn.ModuleList([ml.Embedding(n_tokens, token_dim // n_levels) for l in range(n_levels)])
+        # self.proj = nn.Linear(8 * token_dim, 1 * token_dim)
     def forward(self, xi: Tensor, dropout_mask = None, dropout_token = None ) -> Tensor:
         # empty
         if xi.shape[0] == 0:
             dim = self.embs[0].weight.shape[-1] # self.proj.weight.shape[0]
@@ -102,15 +102,18 @@ class AudioEncoder(nn.Module):
             xi = _dropout_codes( xi, dropout_mask, dropout_token )
         # old way
         # this probably is a tried and true good way to go about it
-        x = sum([ emb( xi[:, l] ) for l, emb in enumerate(self.embs) ])
-        # encode by interleaving
-        # this "works" but I imagine it being excessive and doesn't seem to help the model all that much
+        # in theory RVQ-based codecs should prefer this, but this doesn't yield good results
+        """
+        x = sum([ emb( xi[:, l] ) for l, emb in enumerate(self.embs) ])
+        """
+        # encode by interleaving embeddings into one "token"
+        # this "works" but I imagine it being excessive and doesn't seem to help the model all that much
         x = torch.stack([emb(xi[:, l]) for l, emb in enumerate(self.embs)], dim=1)
         x = x.view(x.shape[0], -1)
-        x = self.proj(x)
+        """
+        if self.proj is not None:
+            x = self.proj(x)
+        """
         return x
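
As an aside, here is a self-contained sketch of the two encoding strategies the comments above weigh against each other. It is not the repo's code: plain nn.Embedding stands in for ml.Embedding and the sizes are made up, but the shapes follow the change in __init__ (per-level width token_dim // n_levels, so stacking and flattening lands back on token_dim and the old 8 * token_dim -> token_dim projection is no longer needed):

    import torch
    import torch.nn as nn

    n_tokens, n_levels, token_dim, seq_len = 1024, 8, 512, 6
    xi = torch.randint(0, n_tokens, (seq_len, n_levels))  # RVQ codes, one column per level

    # "old way": full-width embedding per level, summed into one token embedding
    embs_full = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for _ in range(n_levels)])
    x_sum = sum([emb(xi[:, l]) for l, emb in enumerate(embs_full)])         # [seq_len, token_dim]

    # "interleaved": narrower embeddings (token_dim // n_levels), stacked per level and
    # flattened so each timestep becomes one token_dim-wide vector, no projection needed
    embs_narrow = nn.ModuleList([nn.Embedding(n_tokens, token_dim // n_levels) for _ in range(n_levels)])
    x_cat = torch.stack([emb(xi[:, l]) for l, emb in enumerate(embs_narrow)], dim=1)
    x_cat = x_cat.view(x_cat.shape[0], -1)                                  # [seq_len, token_dim]

    assert x_sum.shape == x_cat.shape == (seq_len, token_dim)
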
@@ -311,8 +314,8 @@ class Base_V2(nn.Module):
             self.n_resp_levels,
         )
         self.len_decoder = AuxDecoder( d_model, 11 )
-        self.text_decoder = AuxDecoder( d_model, n_phn_tokens )
-        self.raw_text_decoder = AuxDecoder( d_model, n_text_tokens )
+        self.phn_decoder = AuxDecoder( d_model, n_phn_tokens )
+        self.text_decoder = AuxDecoder( d_model, n_text_tokens )
         # override any requested padding size
         if attention_backend == "flash_attn_v100":
@@ -400,6 +403,8 @@
         if self.n_experts > 1 and self.training:
             router_logits = output["router_logits"]
             aux_loss = self.model.config.router_aux_loss_coef * load_balancing_loss_func( router_logits, self.model.config.num_local_experts, self.model.config.num_experts_per_tok, m )
+        else:
+            hidden_states = self.model(x)
         # process it into a format that I like
         if output_hidden_states:
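
For context, the aux_loss line above matches the name and call signature of the Mixtral-style load-balancing loss in transformers; which implementation the repo actually imports is not shown in this diff. A minimal hedged sketch of that pattern, with illustrative shapes and an illustrative 0.001 standing in for router_aux_loss_coef (`m` in the diff is presumably the attention mask):

    import torch
    from transformers.models.mixtral.modeling_mixtral import load_balancing_loss_func

    n_moe_layers, batch, seq_len, num_experts, top_k = 2, 1, 8, 4, 2
    # one [batch * seq_len, num_experts] router-logit tensor per MoE layer
    router_logits = tuple(torch.randn(batch * seq_len, num_experts) for _ in range(n_moe_layers))
    attention_mask = torch.ones(batch, seq_len, dtype=torch.long)

    aux_loss = load_balancing_loss_func(router_logits, num_experts, top_k, attention_mask)
    aux_loss = 0.001 * aux_loss  # router_aux_loss_coef * load-balancing loss, as in the diff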