From f69aad9c65fae58efaffb8f4ac09a3f64824eb6a Mon Sep 17 00:00:00 2001
From: mrq
Date: Fri, 8 Sep 2023 15:36:26 -0500
Subject: [PATCH] some day I'll get it right

---
 vall_e/config.py        |  1 -
 vall_e/models/ar_nar.py |  2 +-
 vall_e/models/base.py   | 13 ++++++-------
 3 files changed, 7 insertions(+), 9 deletions(-)

diff --git a/vall_e/config.py b/vall_e/config.py
index f640c19..3f8c11a 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -163,7 +163,6 @@ class Model:
 	arch_type: str = "transformer"
 	training: bool = True
 	interleave: bool = False
-	use_multiembedding: bool = True # nasty bandaid I got myself into
 	frozen_params: list[str] = field(default_factory=lambda: [])
 
 	@property
diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index 5d93c2b..8371950 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -107,7 +107,7 @@ class AR_NAR(Base):
 		prev_list = resps_list
 
 		while True:
-			level = prev_list[0].shape[-1] - 1
+			level = prev_list[0].shape[-1]
 
 			if level >= self.n_resp_levels:
 				break
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index 4645eaa..0d7c6df 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -64,13 +64,10 @@ class MultiEmbedding(nn.Embedding):
 
 	def __init__(self, max_n_levels, n_tokens, token_dim, monolithic=False):
 		super().__init__(max_n_levels, token_dim)
+		self.monolithic = monolithic
 		self.max_n_levels = max_n_levels
 		self.n_tokens = n_tokens
-		self.monolithic = monolithic
-		if self.monolithic:
-			self.weights = nn.Parameter(torch.randn(2, max_n_levels, n_tokens, token_dim))
-		else:
-			self.weight = nn.Parameter(torch.randn(max_n_levels, n_tokens, token_dim))
+		self.weight = nn.Parameter(torch.randn(max_n_levels, n_tokens, token_dim))
 
 	# to-do: select quant level from given quant_levels tensor if given (i.e. through the resp_emb)
 	# I imagine this is an oversight in the NAR.
@@ -78,14 +75,16 @@ class MultiEmbedding(nn.Embedding):
 		if len(x_list) == 0:
 			return []
 
+		# this "strategy" will reserve the weight[0] for te AR and weight[1:] for the NAR
+		# the NAR cannot share RVQ-bin level 0 with the AR for the resp_emb
 		if self.monolithic:
-			w = self.weights[0 if quant_levels is None else 1]
+			w = self.weight[:1] if quant_levels is None else self.weight[1:]
 		else:
 			w = self.weight
 
 		padded_x_list = []
 
-		for xi in x_list:
+		for i, xi in enumerate(x_list):
 			xi = F.one_hot(xi.to(torch.int64), num_classes=self.n_tokens) # t l' k
 			wi = w.shape[0] - xi.shape[1]
 			xi = F.pad(xi, (0, 0, 0, wi)) # t l k
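
For context, below is a minimal, self-contained sketch of the "monolithic" weight-slicing strategy this patch lands on: a single weight of shape (max_n_levels, n_tokens, token_dim), with row 0 reserved for the AR and rows 1: for the NAR, instead of two separate weight tensors. It is not the repo's actual module: the class name TinyMultiEmbedding is hypothetical, and the einsum reduction and tensor shapes are assumptions based on the "t l k" comments in the hunk above, not on the full base.py.

import torch
import torch.nn.functional as F
from torch import nn

class TinyMultiEmbedding(nn.Module):  # hypothetical name, for illustration only
	def __init__(self, max_n_levels, n_tokens, token_dim, monolithic=False):
		super().__init__()
		self.monolithic = monolithic
		self.n_tokens = n_tokens
		# one weight for every RVQ level; row 0 belongs to the AR when monolithic
		self.weight = nn.Parameter(torch.randn(max_n_levels, n_tokens, token_dim))

	def forward(self, x_list, quant_levels=None):
		if len(x_list) == 0:
			return []
		# AR pass (quant_levels is None) only ever embeds RVQ level 0 -> weight[:1];
		# NAR pass embeds the remaining levels -> weight[1:], so the two never share a row
		if self.monolithic:
			w = self.weight[:1] if quant_levels is None else self.weight[1:]
		else:
			w = self.weight
		out = []
		for xi in x_list:
			xi = F.one_hot(xi.to(torch.int64), num_classes=self.n_tokens)  # t l' k
			xi = F.pad(xi, (0, 0, 0, w.shape[0] - xi.shape[1]))            # t l  k
			# sum the per-level token embeddings into one vector per timestep
			out.append(torch.einsum("lkd,tlk->td", w, xi.to(w)))
		return out

Assumed usage, with made-up sizes: the AR call passes a single RVQ level and no quant_levels, the NAR call passes several levels plus a quant_levels tensor.

emb = TinyMultiEmbedding(max_n_levels=8, n_tokens=1024, token_dim=512, monolithic=True)
ar_out  = emb([torch.randint(0, 1024, (50, 1))])                                    # uses weight[:1]
nar_out = emb([torch.randint(0, 1024, (50, 3))], quant_levels=torch.tensor([3]))    # uses weight[1:]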