diff --git a/vall_e/models/ar_nar_v2.py b/vall_e/models/ar_nar_v2.py
index fa049fa..ad1eab7 100644
--- a/vall_e/models/ar_nar_v2.py
+++ b/vall_e/models/ar_nar_v2.py
@@ -845,15 +845,13 @@ def example_usage():
 	kwargs = {
 		'n_audio_tokens': cfg.model.audio_tokens,
 
-		'd_model': 1024, # 256, # 1024, # 1536
-		'n_heads': 16, # 4, # 16, # 24
-		'n_layers': 12, # 32
-		'n_experts': 1 if not cfg.model else cfg.model.experts,
-
+		'd_model': cfg.model.dim,
+		'd_ffn': cfg.model.ffn,
+		'n_heads': cfg.model.heads,
+		'n_layers': cfg.model.layers,
+		'n_experts': cfg.model.experts,
 		'p_dropout': 0.1,
 
-		'l_padding': 8 if cfg.optimizations.fp8 else 0,
-
 		'config': cfg.model
 	}
 
diff --git a/vall_e/models/base_v2.py b/vall_e/models/base_v2.py
index f185dea..a12d93e 100644
--- a/vall_e/models/base_v2.py
+++ b/vall_e/models/base_v2.py
@@ -90,10 +90,10 @@ class AudioEncoder(nn.Module):
 		token_dim: int,
 	):
 		super().__init__()
-		self.embs = nn.ModuleList([ml.Embedding(n_tokens, token_dim) for l in range(n_levels)])
-		self.proj = nn.Linear(8 * token_dim, 1 * token_dim)
+		self.embs = nn.ModuleList([ml.Embedding(n_tokens, token_dim // n_levels) for l in range(n_levels)])
+		# self.proj = nn.Linear(8 * token_dim, 1 * token_dim)
 
-	def forward(self, xi: Tensor, dropout_mask = None, dropout_token = None ) -> Tensor:
+	def forward(self, xi: Tensor, dropout_mask = None, dropout_token = None ) -> Tensor:
 		# empty
 		if xi.shape[0] == 0:
 			dim = self.embs[0].weight.shape[-1] # self.proj.weight.shape[0]
@@ -102,15 +102,18 @@ class AudioEncoder(nn.Module):
 			xi = _dropout_codes( xi, dropout_mask, dropout_token )
 
 		# old way
-		# this probably is a tried and true good way to go about it
-		x = sum([ emb( xi[:, l] ) for l, emb in enumerate(self.embs) ])
-
-		# encode by interleaving
-		# this "works" but I imagine it being excessive and doesn't seem to help the model all that much
+		# in theory RVQ-based codecs should prefer this, but this doesn't yield good results
 		"""
+		x = sum([ emb( xi[:, l] ) for l, emb in enumerate(self.embs) ])
+		"""
+
+		# encode by interleaving embeddings into one "token"
+		# this "works" but I imagine it being excessive and doesn't seem to help the model all that much
 		x = torch.stack([emb(xi[:, l]) for l, emb in enumerate(self.embs)], dim=1)
 		x = x.view(x.shape[0], -1)
-		x = self.proj(x)
+		"""
+		if self.proj is not None:
+			x = self.proj(x)
 		"""
 
 		return x
@@ -311,8 +314,8 @@ class Base_V2(nn.Module):
 			self.n_resp_levels,
 		)
 		self.len_decoder = AuxDecoder( d_model, 11 )
-		self.text_decoder = AuxDecoder( d_model, n_phn_tokens )
-		self.raw_text_decoder = AuxDecoder( d_model, n_text_tokens )
+		self.phn_decoder = AuxDecoder( d_model, n_phn_tokens )
+		self.text_decoder = AuxDecoder( d_model, n_text_tokens )
 
 		# override any requested padding size
 		if attention_backend == "flash_attn_v100":
@@ -400,6 +403,8 @@ class Base_V2(nn.Module):
 		if self.n_experts > 1 and self.training:
 			router_logits = output["router_logits"]
 			aux_loss = self.model.config.router_aux_loss_coef * load_balancing_loss_func( router_logits, self.model.config.num_local_experts, self.model.config.num_experts_per_tok, m )
+		else:
+			hidden_states = self.model(x)
 
 		# process it into a format that I like
 		if output_hidden_states:
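
The base_v2.py change above drops AudioEncoder's 8x-wide projection and instead gives each RVQ level an embedding of width token_dim // n_levels, interleaving (concatenating) the per-level embeddings into a single token_dim-wide vector. Below is a minimal, self-contained sketch of that scheme, not the repo's exact API: it assumes plain nn.Embedding in place of the repo's ml.Embedding, uses an illustrative class name (InterleavedAudioEncoder), and omits the dropout-mask handling and empty-input path.

	# Sketch only: each of n_levels codebooks gets an embedding of width
	# token_dim // n_levels; per-level embeddings are concatenated per frame
	# into one token_dim-wide vector, so no projection layer is needed.
	import torch
	import torch.nn as nn

	class InterleavedAudioEncoder(nn.Module):
		def __init__(self, n_tokens: int, n_levels: int, token_dim: int):
			super().__init__()
			assert token_dim % n_levels == 0, "token_dim must divide evenly across levels"
			self.embs = nn.ModuleList([nn.Embedding(n_tokens, token_dim // n_levels) for _ in range(n_levels)])

		def forward(self, xi: torch.Tensor) -> torch.Tensor:
			# xi: [seq_len, n_levels] codes -> [seq_len, n_levels, token_dim // n_levels]
			x = torch.stack([emb(xi[:, l]) for l, emb in enumerate(self.embs)], dim=1)
			# flatten the level dimension -> [seq_len, token_dim]
			return x.view(x.shape[0], -1)

	# usage with made-up sizes
	enc = InterleavedAudioEncoder(n_tokens=1024, n_levels=8, token_dim=1024)
	codes = torch.randint(0, 1024, (75, 8))  # 75 frames, 8 RVQ levels
	print(enc(codes).shape)                  # torch.Size([75, 1024])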