From 7ce06432fd47e90069a2b19f111457b185463b76 Mon Sep 17 00:00:00 2001
From: mrq
Date: Wed, 6 Sep 2023 19:33:39 -0500
Subject: [PATCH] fixed the AR+NAR dual model, the resp_emb has to be split up
 (classifier might too)

---
 vall_e/models/ar.py     |  4 ++++
 vall_e/models/ar_nar.py | 26 +++++++-------------------
 vall_e/models/base.py   | 32 ++++++++++++++++++++++++--------
 vall_e/models/nar.py    |  4 ++++
 4 files changed, 39 insertions(+), 27 deletions(-)

diff --git a/vall_e/models/ar.py b/vall_e/models/ar.py
index 1c15263..3e527cc 100755
--- a/vall_e/models/ar.py
+++ b/vall_e/models/ar.py
@@ -53,6 +53,10 @@ class AR(Base):
 			return self.config.interleave
 		return False
 
+	@property
+	def dual(self) -> bool:
+		return False
+
 	def _prune(self, l: Tensor):
 		indices = (l == self.stop_token).nonzero()
 		if len(indices) == 0:
diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index 0ce24e4..ccba99c 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -16,7 +16,7 @@ class AR_NAR(Base):
 
 	@property
 	def norm_type(self):
-		return "ln"
+		return "ln" if self.n_resp_levels == 1 else "adaln"
 
 	@property
 	def arch_type(self) -> str:
@@ -44,15 +44,15 @@ class AR_NAR(Base):
 
 	@property
 	def recurrent_chunk_size(self) -> int:
-		if cfg.mode == "training":
-			return 0
-		return cfg.inference.recurrent_chunk_size
+		return 0
 
 	@property
 	def interleave(self) -> bool:
-		if hasattr(self, "config") and self.config:
-			return self.config.interleave
 		return False
+
+	@property
+	def dual(self) -> bool:
+		return True
 
 	def _prune(self, l: Tensor):
 		indices = (l == self.stop_token).nonzero()
@@ -60,18 +60,6 @@ class AR_NAR(Base):
 			return l
 		return l[: indices.min().item()]
 
-	def _interleave( self, codes ):
-		if not self.interleave:
-			return codes
-
-		return codes.flatten()
-
-	def _deinterleave( self, codes, length = 0 ):
-		if not self.interleave:
-			return codes
-
-		return torch.unflatten( codes[:codes.shape[0] // self.n_prom_levels * self.n_prom_levels], 0, ( codes.shape[0] // self.n_prom_levels, self.n_prom_levels ) )
-
 	@staticmethod
 	def _unsqueeze_list(x_list, axis=-1):
 		return [x.unsqueeze(dim=axis) for x in x_list]
@@ -243,7 +231,7 @@ def example_usage():
 
 	def train():
 		engine.train()
-		t = trange(5000)
+		t = trange(1000)
 		for i in t:
 			stats = {"step": i}
 			stats |= engine.traverse(text_list=text_list, proms_list=proms_list, resps_list=resps_list)
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index 7c9f48d..13331b1 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -126,6 +126,10 @@ class Base(nn.Module):
 	def interleave(self) -> bool:
 		return False
 
+	@property
+	def dual(self) -> bool:
+		return False
+
 	@property
 	def stop_token(self):
 		if not self.causal:
@@ -151,7 +155,7 @@ class Base(nn.Module):
 		n_heads: int = 8,
 		n_layers: int = 12,
 		p_dropout: float = 0.1,
-		
+
 		config = None,
 	):
 		super().__init__()
@@ -169,7 +173,11 @@
 
 		self.text_emb = Embedding(n_tokens, d_model)
 		self.proms_emb = MultiEmbedding(self.n_prom_levels, n_prom_tokens, d_model)
-		self.resps_emb = MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model)
+
+		if self.dual:
+			self.resps_emb = nn.ModuleList([MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model) for _ in range(2)])
+		else:
+			self.resps_emb = MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model)
 
 		self.sep = nn.Parameter(torch.randn(d_model))
 
@@ -254,12 +262,20 @@
 		state: dict | None = None,
 	):
 
-		x_list = self._samplewise_merge_tensors(
-			self.text_emb(text_list),
-			self.proms_emb(proms_list),
-			self.resps_emb(resps_list),
-			sep=self.sep,
-		)
+		if self.dual:
+			x_list = self._samplewise_merge_tensors(
+				self.text_emb(text_list),
+				self.proms_emb(proms_list),
+				self.resps_emb[0 if quant_levels is None else 1](resps_list),
+				sep=self.sep,
+			)
+		else:
+			x_list = self._samplewise_merge_tensors(
+				self.text_emb(text_list),
+				self.proms_emb(proms_list),
+				self.resps_emb(resps_list),
+				sep=self.sep,
+			)
 
 		x, m = list_to_tensor(x_list)
 
diff --git a/vall_e/models/nar.py b/vall_e/models/nar.py
index 2409c7b..5f3f234 100755
--- a/vall_e/models/nar.py
+++ b/vall_e/models/nar.py
@@ -47,6 +47,10 @@ class NAR(Base):
 	def interleave(self) -> bool:
 		return False
 
+	@property
+	def dual(self) -> bool:
+		return False
+
 	def forward(
 		self,
 		text_list: list[Tensor],
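
Note (not part of the patch): below is a minimal, self-contained sketch of the routing this commit introduces in base.py. Because the dual AR+NAR model runs both tasks through one transformer, the response embedding is split into two tables, with `resps_emb[0]` serving the AR pass (selected when `quant_levels is None`) and `resps_emb[1]` serving the NAR passes. The class, names, and vocabulary sizes here are hypothetical, not the repo's `MultiEmbedding`; only the index-selection mirrors the patched code.

	import torch
	import torch.nn as nn

	class DualRespEmb(nn.Module):
		"""Two independent response-embedding tables, routed the same way as
		base.py's `self.resps_emb[0 if quant_levels is None else 1](resps_list)`."""
		def __init__(self, n_tokens: int, d_model: int):
			super().__init__()
			# tables[0]: AR embedding, tables[1]: NAR embedding
			self.tables = nn.ModuleList([nn.Embedding(n_tokens, d_model) for _ in range(2)])

		def forward(self, resps: torch.Tensor, quant_levels=None) -> torch.Tensor:
			# AR pass when no quantization levels are given, NAR pass otherwise
			return self.tables[0 if quant_levels is None else 1](resps)

	# usage: the same token ids embed differently depending on the pass
	emb = DualRespEmb(n_tokens=1025, d_model=512)  # assumed: 1024 codes + stop token
	ar_x  = emb(torch.randint(0, 1025, (4, 10)))                  # quant_levels is None -> AR table
	nar_x = emb(torch.randint(0, 1024, (4, 10)), quant_levels=1)  # NAR table, no stop token

Splitting the tables keeps the two objectives from fighting over one embedding space: the AR's level-0-plus-stop-token distribution and the NAR's per-level residual distributions no longer share weights, which is the "resp_emb has to be split up" fix the subject line describes (and why the classifier head may need the same treatment).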