From d0a5c7eca212ad43c1dfbfe28b8eec25368c2284 Mon Sep 17 00:00:00 2001
From: mrq
Date: Sat, 3 Aug 2024 20:23:36 -0500
Subject: [PATCH] more coping with the NAR len

---
 vall_e/models/base.py | 64 +++++++++++++++++++++++++++++--------------
 1 file changed, 43 insertions(+), 21 deletions(-)

diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index a14c850..4340f14 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -91,14 +91,14 @@ class MultiEmbedding(nn.Module):
 		self.n_tokens = n_tokens
 		self.weight = nn.Parameter(torch.randn(max_n_levels, n_tokens, token_dim))
 
-	# to-do: select quant level from given quant_levels tensor if given (i.e. through the resp_emb)
+	# to-do: select quant level from given quant_levels tensor if given (i.e. through the resps_emb)
 	# I imagine this is an oversight in the NAR.
 	def forward(self, x_list: list[Tensor], quant_level: int | list[int] | Tensor | None = None) -> list[Tensor]:
 		if len(x_list) == 0:
 			return []
 
 		# this "strategy" will reserve the weight[0] for te AR and weight[1:] for the NAR
-		# the NAR cannot share RVQ-bin level 0 with the AR for the resp_emb
+		# the NAR cannot share RVQ-bin level 0 with the AR for the resps_emb
 		if self.monolithic:
 			w = self.weight[:1] if quant_level is None or quant_level == 0 else self.weight[1:]
 		else:
@@ -175,8 +175,9 @@ class AudioEmbedding(nn.Module):
 		for i, embedding in enumerate(self.embeddings):
 			embedding.weight = torch.nn.Parameter(torch.zeros( embedding.weight.shape ))
 
-	def external_embeddings(self, input: Tensor) -> Tensor:
-		quant_level = 0 if input.dim() == 1 else input.shape[-1] - 1
+	def external_embeddings(self, input: Tensor, quant_level: int | None = None ) -> Tensor:
+		if quant_level is None:
+			quant_level = 0 if input.dim() == 1 else input.shape[-1] - 1
 
 		# for AR, trim any stop tokens
 		has_stop_token = False
@@ -212,8 +213,9 @@ class AudioEmbedding(nn.Module):
 
 		return embedding
 
-	def internal_forward(self, xi: Tensor, offset: int = 0 ) -> Tensor:
-		quant_level = 0 if xi.dim() == 1 else xi.shape[-1] - 1
+	def internal_forward(self, xi: Tensor, offset: int = 0, quant_level: int | None = None ) -> Tensor:
+		if quant_level is None:
+			quant_level = 0 if xi.dim() == 1 else xi.shape[-1] - 1
 
 		if self.sums and quant_level > 0:
 			x = sum( [ self.embeddings[k + offset]( xi[:, k] ) for k in range( quant_level ) ] )
@@ -223,11 +225,11 @@ class AudioEmbedding(nn.Module):
 
 		return x
 
-	def forward(self, xi: Tensor, offset: int = 0 ) -> Tensor:
-		x = self.internal_forward( xi, offset ) if self.external_mode != "exclusive" or xi.shape[0] == 0 else None
+	def forward(self, xi: Tensor, offset: int = 0, quant_level: int | None = None ) -> Tensor:
+		x = self.internal_forward( xi, offset = offset, quant_level = quant_level ) if self.external_mode != "exclusive" or xi.shape[0] == 0 else None
 
 		if self.external_mode and xi.shape[0] > 0:
-			external_embeddings = self.external_embeddings( xi )
+			external_embeddings = self.external_embeddings( xi, quant_level = quant_level )
 			if self.external_mode == "exclusive":
 				return external_embeddings
 			x += external_embeddings
@@ -952,9 +954,15 @@ class Base(nn.Module):
 
 			# get RVQ level 0, or up to targetted RVQ level inference
 			if self.version <= 4:
-				return self.proms_emb( input if quant_level == 0 else input[:, :quant_level] )
+				return self.proms_emb(
+					input if quant_level == 0 else input[:, :quant_level]
+				)
 
-			return self.proms_emb( input if input.dim() == 1 else input[:, : 1 if quant_level == 0 else quant_level], offset = 0 )
+			return self.proms_emb(
+				input if input.dim() == 1 else input[:, : 1 if quant_level == 0 else quant_level],
+				quant_level = 0 if quant_level == 0 else quant_level - 1, # input is one below the target quant level
+				offset = 0,
+			)
 
 		# yuck
 		token_dropout_rate = self.config.experimental.token_dropout_rate if self.config else 0.0
@@ -972,6 +980,7 @@ class Base(nn.Module):
 			quant_level = quant_levels[batch_index] if quant_levels is not None else 0
 
 			task_type = "tts"
+			input_prom = None
 			for name, input in batch_input:
 				# technically can provide a map for input_name => embedding, but some embedding requires additional processing
 				embedding = None
@@ -992,20 +1001,32 @@ class Base(nn.Module):
 					embedding = self.langs_emb( input )
 				elif name == "prom":
 					proms = [ input ] if isinstance(input, torch.Tensor) else input
+					input_prom = torch.cat([ prom for prom in proms if isinstance(input, torch.Tensor) ])
+
 					embedding = torch.cat( [ prompt_input_to_embedding( input, quant_level ) for input in proms if input is not None ] )
 				elif name == "tone" and self.tones_emb is not None:
 					embedding = self.tones_emb( input )
 				elif name == "resp":
 					if "len" in self.capabilities and quant_level == 0:
-						"""
-						# fill with "stop" tokens for NAR-only model
-						embedding = self.resps_emb(
-							torch.full_like(input if input.dim() == 1 else input[..., 0], self.stop_token),
-							offset = 0
-						)
-						"""
-						# fill with filler tokens for NAR-only model
-						embedding = self.dropout_token.repeat((input.shape[0], 1))
+						if input_prom is not None:
+							# fill with the prom as the initial condition
+							repeat = (input.shape[0] // input_prom.shape[0]) + 1
+							repeated = input_prom[:, :1].repeat((repeat, 1))[:input.shape[0], :1]
+
+							embedding = self.resps_emb(
+								repeated,
+								offset = 0,
+								quant_level = 0,
+							)
+						else:
+							# fill with "stop" token from the len layer for the NAR-only model
+							embedding = self.resps_emb(
+								# self.dropout_token.repeat((input.shape[0], 1)),
+								torch.full_like(input if input.dim() == 1 else input[..., 0], 12),
+								offset = 0,
+								quant_level = 0,
+							)
+
 					else:
 						# get RVQ level 0, or up to targetted RVQ level inference
 						if self.version <= 4:
@@ -1016,7 +1037,8 @@ class Base(nn.Module):
 					else:
 						embedding = self.resps_emb(
 							input if input.dim() == 1 or quant_level == 0 else input[:, :quant_level],
-							offset = 0 if quant_level == 0 or "len" in self.capabilities else 1
+							offset = 1 if "len" in self.capabilities else (0 if quant_level == 0 else 1),
+							quant_level = 0 if quant_level == 0 else quant_level - 1, # input is one below the target quant level
 						)
 
 				# apply token dropout
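
Note (not part of the patch): a rough, self-contained sketch of the convention the new explicit quant_level keyword encodes, "input is one below the target quant level". When targeting RVQ level q, the response codes handed to the embedding only cover levels 0 .. q-1, so the level index passed is q - 1, the same value the old shape[-1] - 1 inference produced, but now usable on paths (such as the prom-filled NAR-len branch) where the tensor's shape no longer reflects the level being targeted. The helper name below is made up for illustration and does not exist in the repository.

import torch
from torch import Tensor

def slice_resps_for_level(resps: Tensor, target_level: int) -> tuple[Tensor, int]:
	# return (codes to embed, quant_level to pass) for a [T, n_levels] response tensor
	if target_level == 0:
		# level 0 / AR-style: the codes themselves are the level-0 sequence
		return resps[:, 0], 0
	# "input is one below the target quant level"
	return resps[:, :target_level], target_level - 1

resps = torch.randint(0, 1024, (100, 8))    # 100 frames, 8 RVQ levels
codes, quant_level = slice_resps_for_level(resps, 3)
assert codes.shape == (100, 3) and quant_level == 2
assert quant_level == codes.shape[-1] - 1   # matches the old shape-based inference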