From e7996657597b2175944c36e7d6ba4bbcdbe130a1 Mon Sep 17 00:00:00 2001 From: mrq Date: Thu, 25 Jan 2024 12:18:48 -0600 Subject: [PATCH] experimental weighting of prom/resp embeds --- vall_e/models/ar_nar.py | 4 ++-- vall_e/models/base.py | 11 ++++++----- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py index 4207a00..05808e5 100644 --- a/vall_e/models/ar_nar.py +++ b/vall_e/models/ar_nar.py @@ -94,7 +94,7 @@ class AR_NAR(Base): lang_list: list[Tensor] | None = None, max_steps: int = 1000, - max_levels: int = 7, + max_levels: int = 0, max_resp_context: int = -1, sampling_temperature: float = 1.0, @@ -166,7 +166,7 @@ class AR_NAR(Base): ) # is NAR if max_levels == 0: - max_levels = self.n_resp_levels + max_levels = self.n_resp_levels - 1 prev_list = resps_list diff --git a/vall_e/models/base.py b/vall_e/models/base.py index 848c6b9..f678a2d 100755 --- a/vall_e/models/base.py +++ b/vall_e/models/base.py @@ -123,9 +123,10 @@ class MultiEmbedding(nn.Module): # Embedding that sums each RVQ-bin level within a given input acoustic prompt class AudioEmbedding(nn.Module): - def __init__(self, l_tokens, token_dim): + def __init__(self, l_tokens, token_dim, levels=None): super().__init__() self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for n_tokens in l_tokens]) + self.weight = nn.ParameterList([nn.Parameter( torch.Tensor([1]) ) for i in range(levels)]) if levels is not None else None def forward(self, x_list: list[Tensor], quant_levels: Tensor | None = None ) -> list[Tensor]: res_list = [] @@ -133,13 +134,13 @@ class AudioEmbedding(nn.Module): for i, xi in enumerate(x_list): # prom if quant_levels is None and xi.shape[-1] > 1: - x = sum( [ self.embeddings[k]( xi[:, k] ) for k in range(xi.shape[-1]) ] ) + x = sum( [ self.embeddings[k]( xi[:, k] ) * (self.weight[k] if self.weight is not None else 1) for k in range(xi.shape[-1]) ] ) # AR resp elif quant_levels is None or quant_levels[i] == 0: x = self.embeddings[0]( xi[:, 0] ) # NAR resp else: - x = sum( [ self.embeddings[k+1]( xi[:, k] ) for k in range(xi.shape[-1]) ] ) + x = sum( [ self.embeddings[k+1]( xi[:, k] ) * (self.weight[k+1] if self.weight is not None else 1) for k in range(xi.shape[-1]) ] ) res_list.append(x) return res_list @@ -255,9 +256,9 @@ class Base(nn.Module): self.resps_emb = MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model, monolithic=self.monolithic) else: # [1024] * 8 - self.proms_emb = AudioEmbedding([n_prom_tokens] * self.n_prom_levels, d_model) + self.proms_emb = AudioEmbedding([n_prom_tokens] * self.n_prom_levels, d_model, self.n_prom_levels if self.version > 3 else None) # [1025] + [1024] * 8 - self.resps_emb = AudioEmbedding([n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1), d_model) + self.resps_emb = AudioEmbedding([n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1), d_model, self.n_resp_levels if self.version > 3 else None) if self.version >= 3: