From c47fc3274eaf901e1eb4beeedb1a9f5966135eab Mon Sep 17 00:00:00 2001
From: mrq
Date: Thu, 7 Sep 2023 17:12:17 -0500
Subject: [PATCH] added backwards compat flag

---
 vall_e/config.py      |  1 +
 vall_e/models/base.py | 12 ++++++------
 2 files changed, 7 insertions(+), 6 deletions(-)

diff --git a/vall_e/config.py b/vall_e/config.py
index 7bcdaab..2d299cb 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -163,6 +163,7 @@ class Model:
 	arch_type: str = "transformer"
 	training: bool = True
 	interleave: bool = False
+	use_multiembedding: bool = True # nasty bandaid I got myself into
 
 	@property
 	def full_name(self):
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index e0aab7d..164fc46 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -156,10 +156,6 @@ class Base(nn.Module):
 	def n_embeddings(self) -> int:
 		return self.n_resp_levels if self.monolithic else 1
 
-	@property
-	def use_old_embeddings(self) -> bool:
-		return True
-
 	@property
 	def stop_token(self):
 		if not self.causal:
@@ -187,11 +183,15 @@
 		p_dropout: float = 0.1,
 
 		config = None,
+		use_multiembedding = False,
 	):
 		super().__init__()
 
 		self.config = config
 		self.activation_checkpointing = self.config.activation_checkpointing if self.config is not None else True
+		if self.config is not None and hasattr(self.config, "use_multiembedding"):
+			use_multiembedding = self.config.use_multiembedding
+
 		self.n_tokens = n_tokens
 		self.d_model = d_model
 		self.n_heads = n_heads
@@ -205,14 +205,14 @@
 
 		# use dedicated embeddings for each RVQ-bin level in the input acoustic prompt if requested
 		# n_embeddings == prom_levels because using the MultiEmbedding is more than fine for the input acoustic prompt
-		if self.n_embeddings == self.n_prom_levels or not self.use_old_embeddings:
+		if self.n_embeddings == self.n_prom_levels or not use_multiembedding:
 			self.proms_emb = PromEmbedding(self.n_prom_levels, n_prom_tokens, d_model)
 		else:
 			self.proms_emb = MultiEmbedding(self.n_prom_levels, n_prom_tokens, d_model)
 
 		# use dedicated embeddings for each RVQ-bin level in the output response / target if requested
 		# n_embeddings > 1 because the using the MultiEmbedding "works" fine enough for split AR/NARs.
-		if self.n_embeddings > 1 or not self.use_old_embeddings:
+		if self.n_embeddings > 1 or not use_multiembedding:
 			self.resps_emb = RespEmbedding(self.n_embeddings, n_resp_tokens, d_model)
 		else:
 			self.resps_emb = MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model)
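
A minimal standalone sketch of the precedence the patch sets up, for reference only; the resolve_use_multiembedding helper and the SimpleNamespace stand-in for the Model config are hypothetical and not part of the patch:

from types import SimpleNamespace

def resolve_use_multiembedding(config=None, use_multiembedding=False):
	# Mirrors the check added to Base.__init__ above: a config object that
	# carries the attribute overrides the keyword default; otherwise the
	# caller's default is kept.
	if config is not None and hasattr(config, "use_multiembedding"):
		use_multiembedding = config.use_multiembedding
	return use_multiembedding

# The patched Model dataclass defaults the field to True, so existing configs
# stay on MultiEmbedding; passing no config (or an older one without the field)
# falls back to the keyword default and picks the dedicated per-level embeddings.
assert resolve_use_multiembedding(SimpleNamespace(use_multiembedding=True)) is True
assert resolve_use_multiembedding(None) is False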