diff --git a/vall_e/models/ar.py b/vall_e/models/ar.py
index 3e527cc..7861253 100755
--- a/vall_e/models/ar.py
+++ b/vall_e/models/ar.py
@@ -54,7 +54,7 @@ class AR(Base):
 		return False
 
 	@property
-	def dual(self) -> bool:
+	def monolithic(self) -> bool:
 		return False
 
 	def _prune(self, l: Tensor):
diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index 1bbad85..49ea8b9 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -52,7 +52,7 @@ class AR_NAR(Base):
 		return False
 
 	@property
-	def dual(self) -> bool:
+	def monolithic(self) -> bool:
 		return True
 
 	def _prune(self, l: Tensor):
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index 8ca5644..8a98426 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -78,7 +78,8 @@ class MultiEmbedding(nn.Embedding):
 
 		for xi in x_list:
 			xi = F.one_hot(xi.to(torch.int64), num_classes=self.n_tokens) # t l' k
-			xi = F.pad(xi, (0, 0, 0, w.shape[0] - xi.shape[1])) # t l k
+			wi = w.shape[0] - xi.shape[1]
+			xi = F.pad(xi, (0, 0, 0, wi)) # t l k
 			padded_x_list.append(xi.to(w))
 
 		x = torch.cat(padded_x_list) # n l k
@@ -87,7 +88,8 @@ class MultiEmbedding(nn.Embedding):
 		x_list = x.split([*map(len, x_list)])
 
 		return x_list
-"""
+
+# Embedding that sums each RVQ-bin level within a given input acoustic prompt
 class PromEmbedding(nn.Module):
 	def __init__(self, n_levels, n_tokens, token_dim):
 		super().__init__()
@@ -95,11 +97,9 @@ class PromEmbedding(nn.Module):
 		self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for _ in range(self.n_levels)])
 
 	def forward(self, x_list: list[Tensor] ) -> list[Tensor]:
-		if len(x_list) == 0:
-			return []
-
 		return [ sum([ self.embeddings[k](xi[:, k]) for k in range(xi.shape[-1]) ]) for i, xi in enumerate(x_list) ]
 
+# Embedding that selects which embedding based on a quant_level tensor for a given batch
 class RespEmbedding(nn.Module):
 	def __init__(self, n_levels, n_tokens, token_dim):
 		super().__init__()
@@ -107,11 +107,8 @@ class RespEmbedding(nn.Module):
 		self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for _ in range(self.n_levels)])
 
 	def forward(self, x_list: list[Tensor], quant_levels: Tensor | None = None) -> list[Tensor]:
-		if len(x_list) == 0:
-			return []
-		res = [ self.embeddings[quant_levels[i] if quant_levels is not None else 0](xi) for i, xi in enumerate(x_list) ]
-		return res
-"""
+		return [ self.embeddings[min(self.n_levels - 1, quant_levels[i]) if quant_levels is not None else 0](xi)[:, 0, :] for i, xi in enumerate(x_list) ] # clamp to the last valid index: self.embeddings holds n_levels entries
+
 class Base(nn.Module):
 	@property
 	def causal(self) -> bool:
@@ -150,12 +147,12 @@ class Base(nn.Module):
 		return False
 
 	@property
-	def dual(self) -> bool:
+	def monolithic(self) -> bool:
 		return False
 
 	@property
 	def n_embeddings(self):
-		return self.n_resp_levels if self.dual else 1
+		return self.n_resp_levels if self.monolithic else 1
 
 	@property
 	def stop_token(self):
@@ -200,17 +197,19 @@ class Base(nn.Module):
 
 		self.text_emb = Embedding(n_tokens, d_model)
 
-		self.proms_emb = MultiEmbedding(self.n_prom_levels, n_prom_tokens, d_model)
-		if self.n_embeddings == 1:
-			self.resps_emb = MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model)
+		# use dedicated embeddings for each RVQ-bin level in the input acoustic prompt if requested
+		# n_embeddings == prom_levels because using the MultiEmbedding is more than fine for the input acoustic prompt
+		if self.n_embeddings == self.n_prom_levels:
+			self.proms_emb = PromEmbedding(self.n_prom_levels, n_prom_tokens, d_model)
 		else:
-			self.resps_emb = nn.ModuleList([ MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model) for _ in range(self.n_embeddings) ])
-		"""
-		if self.n_embeddings == 1:
-			self.resps_emb = MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model)
+			self.proms_emb = MultiEmbedding(self.n_prom_levels, n_prom_tokens, d_model)
+
+		# use dedicated embeddings for each RVQ-bin level in the output response / target if requested
+		# n_embeddings > 1 because using the MultiEmbedding "works" fine enough for split AR/NARs.
+		if self.n_embeddings > 1:
+			self.resps_emb = RespEmbedding(self.n_embeddings, n_resp_tokens, d_model)
 		else:
-			self.resps_emb = RespEmbedding(self.n_resp_levels, n_resp_tokens, d_model)
-		"""
+			self.resps_emb = MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model)
 
 		self.sep = nn.Parameter(torch.randn(d_model))
 
@@ -295,21 +294,12 @@ class Base(nn.Module):
 		state: dict | None = None,
 	):
 
-		if self.n_embeddings == 1:
-			x_list = self._samplewise_merge_tensors(
-				self.text_emb(text_list),
-				self.proms_emb(proms_list),
-				self.resps_emb(resps_list),
-				sep=self.sep,
-			)
-		else:
-			x_list = self._samplewise_merge_tensors(
-				self.text_emb(text_list),
-				self.proms_emb(proms_list),
-				self.resps_emb[0 if quant_levels is None else 1](resps_list),
-				#self.resps_emb(resps_list, quant_levels),
-				sep=self.sep,
-			)
+		x_list = self._samplewise_merge_tensors(
+			self.text_emb(text_list),
+			self.proms_emb(proms_list),
+			self.resps_emb(resps_list, quant_levels),
+			sep=self.sep,
+		)
 
 		x, m = list_to_tensor(x_list)
 
diff --git a/vall_e/models/nar.py b/vall_e/models/nar.py
index 878ff15..a1fe500 100755
--- a/vall_e/models/nar.py
+++ b/vall_e/models/nar.py
@@ -48,7 +48,7 @@ class NAR(Base):
 		return False
 
 	@property
-	def dual(self) -> bool:
+	def monolithic(self) -> bool:
 		return False
 
 	def forward(