maybe this will work
commit b8e9f3d785
parent 01e96bafc9
@@ -845,15 +845,13 @@ def example_usage():
     kwargs = {
         'n_audio_tokens': cfg.model.audio_tokens,
 
-        'd_model': 1024, # 256, # 1024, # 1536
-        'n_heads': 16, # 4, # 16, # 24
-        'n_layers': 12, # 32
-        'n_experts': 1 if not cfg.model else cfg.model.experts,
-
+        'd_model': cfg.model.dim,
+        'd_ffn': cfg.model.ffn,
+        'n_heads': cfg.model.heads,
+        'n_layers': cfg.model.layers,
+        'n_experts': cfg.model.experts,
         'p_dropout': 0.1,
 
         'l_padding': 8 if cfg.optimizations.fp8 else 0,
 
         'config': cfg.model
     }

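For reference, the hunk above replaces hardcoded hyperparameters with values read off the config. A minimal sketch of that pattern, assuming a namespace-style cfg; the numbers below are placeholders for illustration, not the project's defaults:

from types import SimpleNamespace

# Hypothetical namespace-style config; field values are made up for illustration.
cfg = SimpleNamespace(
    model=SimpleNamespace(audio_tokens=1024, dim=1024, ffn=4096, heads=16, layers=12, experts=1),
    optimizations=SimpleNamespace(fp8=False),
)

kwargs = {
    'n_audio_tokens': cfg.model.audio_tokens,
    'd_model': cfg.model.dim,
    'd_ffn': cfg.model.ffn,
    'n_heads': cfg.model.heads,
    'n_layers': cfg.model.layers,
    'n_experts': cfg.model.experts,
    'p_dropout': 0.1,
    'l_padding': 8 if cfg.optimizations.fp8 else 0,
    'config': cfg.model,
}
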
@@ -90,10 +90,10 @@ class AudioEncoder(nn.Module):
         token_dim: int,
     ):
         super().__init__()
-        self.embs = nn.ModuleList([ml.Embedding(n_tokens, token_dim) for l in range(n_levels)])
-        self.proj = nn.Linear(8 * token_dim, 1 * token_dim)
+        self.embs = nn.ModuleList([ml.Embedding(n_tokens, token_dim // n_levels) for l in range(n_levels)])
+        # self.proj = nn.Linear(8 * token_dim, 1 * token_dim)
 
     def forward(self, xi: Tensor, dropout_mask = None, dropout_token = None ) -> Tensor:
         # empty
         if xi.shape[0] == 0:
             dim = self.embs[0].weight.shape[-1] # self.proj.weight.shape[0]

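The constructor change pairs with the forward change in the next hunk: each of the n_levels embeddings now carries token_dim // n_levels features, so concatenating them recovers the full token width and the old 8 * token_dim -> token_dim projection is no longer needed. A quick dimension check, with sizes assumed only for illustration:

# Assumed sizes, for illustration only.
n_levels, token_dim = 8, 512
per_level = token_dim // n_levels           # 64 features per codebook level
assert n_levels * per_level == token_dim    # concatenation restores the model width
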
@@ -102,15 +102,18 @@ class AudioEncoder(nn.Module):
         xi = _dropout_codes( xi, dropout_mask, dropout_token )
 
         # old way
         # this probably is a tried and true good way to go about it
-        x = sum([ emb( xi[:, l] ) for l, emb in enumerate(self.embs) ])
-
-        # encode by interleaving
-        # this "works" but I imagine it being excessive and doesn't seem to help the model all that much
+        # in theory RVQ-based codecs should prefer this, but this doesn't yield good results
+        """
+        x = sum([ emb( xi[:, l] ) for l, emb in enumerate(self.embs) ])
+        """
 
+        # encode by interleaving embeddings into one "token"
+        # this "works" but I imagine it being excessive and doesn't seem to help the model all that much
         x = torch.stack([emb(xi[:, l]) for l, emb in enumerate(self.embs)], dim=1)
         x = x.view(x.shape[0], -1)
-        x = self.proj(x)
+        """
+        if self.proj is not None:
+            x = self.proj(x)
+        """
 
         return x

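The comments in this hunk carry the actual reasoning: summing per-level embeddings mirrors how an RVQ codec reconstructs audio (residual codebook vectors are summed), while the interleaved variant concatenates narrower per-level embeddings into one token-width vector. A toy comparison of the two combinations, using plain nn.Embedding in place of the repo's ml.Embedding and made-up sizes:

import torch
import torch.nn as nn

n_tokens, n_levels, token_dim = 1024, 8, 512
codes = torch.randint(0, n_tokens, (4, n_levels))  # 4 timesteps, 8 codebook levels

# "old way": every level gets a full token_dim embedding and the outputs are summed,
# mirroring how an RVQ codec sums its residual codebook vectors
sum_embs = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for _ in range(n_levels)])
x_sum = sum(emb(codes[:, l]) for l, emb in enumerate(sum_embs))

# interleaving: every level gets a token_dim // n_levels embedding and the slices are
# concatenated into a single "token", so no projection back down is required
ilv_embs = nn.ModuleList([nn.Embedding(n_tokens, token_dim // n_levels) for _ in range(n_levels)])
x_ilv = torch.stack([emb(codes[:, l]) for l, emb in enumerate(ilv_embs)], dim=1)
x_ilv = x_ilv.view(x_ilv.shape[0], -1)

print(x_sum.shape, x_ilv.shape)  # both torch.Size([4, 512])
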
@@ -311,8 +314,8 @@ class Base_V2(nn.Module):
             self.n_resp_levels,
         )
         self.len_decoder = AuxDecoder( d_model, 11 )
-        self.text_decoder = AuxDecoder( d_model, n_phn_tokens )
-        self.raw_text_decoder = AuxDecoder( d_model, n_text_tokens )
+        self.phn_decoder = AuxDecoder( d_model, n_phn_tokens )
+        self.text_decoder = AuxDecoder( d_model, n_text_tokens )
 
         # override any requested padding size
         if attention_backend == "flash_attn_v100":

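AuxDecoder itself isn't shown in this diff; the hunk only renames the phoneme and text heads so their names match what they decode. As a rough stand-in, an auxiliary decoder head of this kind is often just a linear classifier over the model width; the sketch below is an assumption for illustration, not the project's actual class:

import torch.nn as nn

# Hypothetical stand-in for AuxDecoder; the real implementation may differ.
class AuxDecoderSketch(nn.Module):
    def __init__(self, d_model: int, n_tokens: int):
        super().__init__()
        self.head = nn.Linear(d_model, n_tokens)

    def forward(self, x):
        return self.head(x)

# mirrors the renamed heads above, with made-up vocabulary sizes
d_model, n_phn_tokens, n_text_tokens = 1024, 256, 8192
len_decoder = AuxDecoderSketch(d_model, 11)
phn_decoder = AuxDecoderSketch(d_model, n_phn_tokens)
text_decoder = AuxDecoderSketch(d_model, n_text_tokens)
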
@@ -400,6 +403,8 @@ class Base_V2(nn.Module):
         if self.n_experts > 1 and self.training:
             router_logits = output["router_logits"]
             aux_loss = self.model.config.router_aux_loss_coef * load_balancing_loss_func( router_logits, self.model.config.num_local_experts, self.model.config.num_experts_per_tok, m )
         else:
             hidden_states = self.model(x)
+
+        # process it into a format that I like
         if output_hidden_states:
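For context, the aux_loss line in this hunk scales a router load-balancing term by router_aux_loss_coef whenever more than one expert is active. A generic sketch of what such a term computes, in the Switch Transformer / Mixtral style; this is not the exact load_balancing_loss_func called above, and the shapes are made up:

import torch
import torch.nn.functional as F

def load_balancing_loss_sketch(router_logits: torch.Tensor, num_experts: int, top_k: int) -> torch.Tensor:
    # router_logits: [num_tokens, num_experts]
    probs = F.softmax(router_logits, dim=-1)
    _, selected = torch.topk(probs, top_k, dim=-1)           # chosen experts per token
    expert_mask = F.one_hot(selected, num_experts).float()   # [num_tokens, top_k, num_experts]
    tokens_per_expert = expert_mask.mean(dim=(0, 1))         # fraction of assignments each expert receives
    router_prob_per_expert = probs.mean(dim=0)               # mean router probability per expert
    return num_experts * torch.sum(tokens_per_expert * router_prob_per_expert)

router_logits = torch.randn(16, 4)  # 16 tokens, 4 experts (made up)
aux_loss = 0.01 * load_balancing_loss_sketch(router_logits, num_experts=4, top_k=2)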