From 5120ffdda7395b1f873d50d272aa397ecefe76fe Mon Sep 17 00:00:00 2001
From: mrq
Date: Mon, 29 Apr 2024 18:24:05 -0500
Subject: [PATCH] god it would be nice to know the best way to handle audio
 embeddings, because I genuinely don't know without skimming through papers or
 devoting X amount of GPU hours in training

---
 vall_e/config.py      |  1 +
 vall_e/models/base.py | 43 ++++++++++++++++++++++++++++++++++---------
 2 files changed, 35 insertions(+), 9 deletions(-)

diff --git a/vall_e/config.py b/vall_e/config.py
index 580fde6..b6cfe51 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -190,6 +190,7 @@ class Model:
 	p_ar_level: float | str = "auto" # determines odds of selecting the AR (level 0) when training, "auto" for default behavior
 	frozen_params: list[str] = field(default_factory=lambda: []) # frozen parameters that are not updated when training
 	attention: str = "eager" # or flash_attention_2
+	audio_embedding_sums: bool = True
 
 	def get(self, name=None):
 		return [ self ] if not name or self.name == name else []
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index a953b53..f051e6d 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -174,11 +174,8 @@ class Embedding(nn.Embedding):
 		return super().forward(torch.cat(x_list)).split([*map(len, x_list)])
 """
 
+# Deprecated implementation
 class MultiEmbedding(nn.Module):
-	"""
-	This embedding sums embeddings on different levels.
-	"""
-
 	def __init__(self, max_n_levels, n_tokens, token_dim, monolithic=False):
 		super().__init__()
 		self.monolithic = monolithic
@@ -216,21 +213,41 @@ class MultiEmbedding(nn.Module):
 
 # Embedding that sums each RVQ-bin level within a given input acoustic prompt
 class AudioEmbedding(nn.Module):
-	def __init__(self, l_tokens, token_dim, levels=None):
+	def __init__(
+		self,
+		l_tokens: list[int], # number of tokens per level (needed because the AR level includes a stop token)
+		token_dim: int, # dimensionality of the embedding
+		levels: int | None = None, # number of RVQ bin levels to allocate per-level weights for
+		sums: bool = True # whether to sum embeddings across RVQ bin levels (it's unclear which way is better)
+	):
 		super().__init__()
+		# array of embeddings
+		#	proms are [0, prom_levels]
+		#	resps are split so [0] is for the AR and [1:] are reserved for the NAR
 		self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for n_tokens in l_tokens])
+		# per-level scaling weights (arguably redundant, since the embedding weights themselves should learn this)
 		self.weight = nn.ParameterList([nn.Parameter( torch.Tensor([1]) ) for i in range(levels)]) if levels is not None else None
+		# whether embeddings are summed across levels
+		self.sums = sums
 
 	def forward(self, xi: Tensor, quant_levels: Tensor | None = None ) -> Tensor:
 		# prom
 		if quant_levels is None and xi.shape[-1] > 1:
-			x = sum( [ self.embeddings[k]( xi[:, k] ) * (self.weight[k] if self.weight is not None else 1) for k in range(xi.shape[-1]) ] )
+			if self.sums:
+				x = sum( [ self.embeddings[k]( xi[:, k] ) * (self.weight[k] if self.weight is not None else 1) for k in range(xi.shape[-1]) ] )
+			else:
+				k = 0 # only use the most significant RVQ bin level for the input prom
+				x = self.embeddings[k]( xi[:, k] ) * (self.weight[k] if self.weight is not None else 1)
 		# AR resp
 		elif quant_levels is None or quant_levels == 0:
 			x = self.embeddings[0]( xi[:, 0] )
 		# NAR resp
 		else:
-			x = sum( [ self.embeddings[k+1]( xi[:, k] ) * (self.weight[k+1] if self.weight is not None else 1) for k in range(xi.shape[-1]) ] )
+			if self.sums:
+				x = sum( [ self.embeddings[k+1]( xi[:, k] ) * (self.weight[k+1] if self.weight is not None else 1) for k in range(xi.shape[-1]) ] )
+			else:
+				k = xi.shape[-1] - 1 # only use the previous RVQ bin level for the current resp embedding
+				x = self.embeddings[k+1]( xi[:, k] ) * (self.weight[k+1] if self.weight is not None else 1)
 
 		return x
 
@@ -345,9 +362,17 @@ class Base(nn.Module):
 			self.resps_emb = MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model, monolithic=self.monolithic)
 		else:
 			# [1024] * 8
-			self.proms_emb = AudioEmbedding([n_prom_tokens] * self.n_prom_levels, d_model, self.n_prom_levels if self.version > 3 else None)
+			self.proms_emb = AudioEmbedding(
+				[n_prom_tokens] * self.n_prom_levels, d_model,
+				levels=self.n_prom_levels if self.version > 3 else None,
+				sums=self.config.audio_embedding_sums
+			)
 			# [1025] + [1024] * 8
-			self.resps_emb = AudioEmbedding([n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1), d_model, self.n_resp_levels if self.version > 3 else None)
+			self.resps_emb = AudioEmbedding(
+				[n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1), d_model,
+				levels=self.n_resp_levels if self.version > 3 else None,
+				sums=self.config.audio_embedding_sums
+			)
 
 		if self.version >= 3:
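
---

Not part of the patch, just a scratch note: a minimal standalone sketch of what the sums toggle changes on the prom path. The sizes and names here (n_levels=8, n_tokens=1024, token_dim=16) are illustrative assumptions, not vall_e's actual defaults, and the per-level weights are left out for brevity.

# sketch: effect of sums=True vs sums=False on the prom embedding path
# (sizes below are illustrative assumptions, not vall_e defaults)
import torch
import torch.nn as nn

n_levels, n_tokens, token_dim = 8, 1024, 16
embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for _ in range(n_levels)])

# xi: a prom of 5 timesteps carrying all 8 RVQ bin levels, shape [5, 8]
xi = torch.randint(0, n_tokens, (5, 8))

# sums=True: every RVQ bin level contributes to the input embedding
x_summed = sum(embeddings[k](xi[:, k]) for k in range(xi.shape[-1]))

# sums=False: only the most significant (first) RVQ bin level is used
x_single = embeddings[0](xi[:, 0])

print(x_summed.shape, x_single.shape)  # both torch.Size([5, 16])

The NAR resp path's non-summing branch is the same idea shifted by one: it keeps only the last provided bin level, xi[:, -1], and runs it through embeddings[xi.shape[-1]], since embeddings[0] belongs to the AR.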