From 516b0894d7c813276013c59fe1ff85a2dc819ffb Mon Sep 17 00:00:00 2001 From: mrq Date: Thu, 6 Jun 2024 19:41:26 -0500 Subject: [PATCH] m --- vall_e/models/base.py | 42 ++++++++++++------------------------------ 1 file changed, 12 insertions(+), 30 deletions(-) diff --git a/vall_e/models/base.py b/vall_e/models/base.py index 434ffd6..43a794d 100755 --- a/vall_e/models/base.py +++ b/vall_e/models/base.py @@ -105,9 +105,7 @@ class AudioEmbedding_Old(nn.Module): self, l_tokens: int, # list of number of tokens (needed because AR resps includes stop token) token_dim: int, # dimensionality of the embedding - mode: "old", # old | prom | resp levels: int | None = None, # number of RVQ-bins (I don't remember the specifics) - sums: bool = True # whether to sum all previous layers of embeddings to factor in other RVQ bin levels (I do not know which way is better) ): super().__init__() # array of embeddings @@ -115,31 +113,19 @@ class AudioEmbedding_Old(nn.Module): # resp are split to where [0] is for the AR, and [1:] are reserved for NAR self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for n_tokens in l_tokens]) # weight influencer for the influence for each level (desu this should be really useless because the weights in the embedding themselves should factor this) - self.weight = nn.ParameterList([nn.Parameter( torch.Tensor([1]) ) for i in range(levels)]) if levels is not None and mode == "old" else None - # - self.mode = mode - # - self.sums = sums + self.weight = nn.ParameterList([nn.Parameter( torch.Tensor([1]) ) for i in range(levels)]) if levels is not None else None def forward(self, xi: Tensor, quant_levels: Tensor | None = None ) -> Tensor: # prom if quant_levels is None and xi.shape[-1] > 1: - if self.sums: - x = sum( [ self.embeddings[k]( xi[:, k] ) * (self.weight[k] if self.weight is not None else 1) for k in range(xi.shape[-1]) ] ) - else: - k = 0 # only use the most significant RVQ bin level for the input prom - x = self.embeddings[k]( xi[:, k] ) * (self.weight[k] if self.weight is not None else 1) + x = sum( [ self.embeddings[k]( xi[:, k] ) * (self.weight[k] if self.weight is not None else 1) for k in range(xi.shape[-1]) ] ) # AR resp elif quant_levels is None or quant_levels == 0: x = self.embeddings[0]( xi if len(xi.shape) == 1 else xi[:, 0] ) # NAR resp else: - if self.sums: - x = sum( [ self.embeddings[k+1]( xi[:, k] ) * (self.weight[k+1] if self.weight is not None else 1) for k in range(xi.shape[-1]) ] ) - else: - k = xi.shape[-1] - 1 # only use the previous RVQ bin level for the current resp embedding - x = self.embeddings[k+1]( xi[:, k] ) * (self.weight[k+1] if self.weight is not None else 1) - + x = sum( [ self.embeddings[k+1]( xi[:, k] ) * (self.weight[k+1] if self.weight is not None else 1) for k in range(xi.shape[-1]) ] ) + return x class AudioEmbedding(nn.Module): @@ -165,16 +151,16 @@ class AudioEmbedding(nn.Module): if quant_level is None: quant_level = 0 if xi.dim() == 1 else xi.shape[-1] - 1 - # embeddings for AR/NAR cannot be shared - offset = 0 if self.mode == "prom" or quant_level == 0 else 1 - + # jank but needed if xi.dim() == 1: - x = self.embeddings[quant_level]( xi ) - elif self.sums and quant_level > 0: + return self.embeddings[quant_level]( xi ) + + offset = 0 if self.mode == "prom" else 1 + if self.sums and quant_level > 0: x = sum( [ self.embeddings[k + offset]( xi[:, k] ) for k in range( quant_level ) ] ) else: - k = quant_level - x = self.embeddings[k + offset]( xi[:, k] ) + k = quant_level - 1 + x = self.embeddings[k + offset]( xi if xi.dim() == 1 else xi[:, k] ) return x @@ -302,15 +288,11 @@ class Base(nn.Module): self.proms_emb = AudioEmbedding_Old( [n_prom_tokens] * self.n_prom_levels, d_model, levels=self.n_prom_levels if self.version > 3 else None, - mode="prom" if self.version >= 5 else "old", - sums=self.config.audio_embedding_sums if self.config is not None else True, ) # [1024 + STOP] + [1024] * 8 self.resps_emb = AudioEmbedding_Old( [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1), d_model, levels=self.n_resp_levels if self.version > 3 else None, - mode="resp" if self.version >= 5 else "old", - sums=self.config.audio_embedding_sums if self.config is not None else True ) else: self.proms_emb = AudioEmbedding( @@ -686,7 +668,7 @@ class Base(nn.Module): elif name == "lang" and self.langs_emb is not None: embedding = self.langs_emb( input ) elif name == "prom": - embedding = self.proms_emb( input if input.dim() == 1 or quant_level == 0 else input[:, :quant_level] ) + embedding = self.proms_emb( input if input.dim() == 1 or quant_level == 0 else input[:, :quant_level], quant_level if self.version >= 5 else None ) elif name == "tone" and self.tones_emb is not None: embedding = self.tones_emb( input ) elif name == "resp":