diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index 0915f89..0b38be9 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -150,20 +150,8 @@ class AR_NAR(Base):
 				quant_levels = torch.Tensor([ generate(0 if self.causal else 1, self.n_resp_levels) for _ in range(batch_size) ]).to(dtype=torch.int16)
 			else:
 				quant_levels = torch.randint(0 if self.causal else 1, self.n_resp_levels, (batch_size,)) # randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
-			"""
-			if cfg.model.p_ar_level == "auto" or cfg.model.p_ar_level is None:
-				quant_levels = torch.randint(0, self.n_resp_levels, (batch_size,)) # randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
-			else:
-				quant_levels = torch.Tensor([ 0 if random.random() < cfg.model.p_ar_level else random.randint(1, self.n_resp_levels) for _ in range(batch_size) ])
-			"""
-			targ_list = [r[..., l] for r, l in zip(resps_list, quant_levels)] # ensures we only have 1 RVQ-bin (our target)
-			resps_list = [r[..., 0] if l == 0 else r[..., :l] for r, l in zip(resps_list, quant_levels)] # r if l == 0 is technically correct since only r[:, 0] is passed through the embedding, but this should save some VRAM
-
-			"""
-			if cfg.experimental:
-				proms_list = [ r if l == 0 else trim(r, cfg.dataset.frames_per_second * 3) for r, l in zip(proms_list, quant_levels) ] # trim input prompt to 3 seconds
-			"""
+			resps_list = [r[..., 0] if l == 0 else r[..., :l+1] for r, l in zip(resps_list, quant_levels)] # r if l == 0 is technically correct since only r[:, 0] is passed through the embedding, but this should save some VRAM

 			# append stop tokens for AR
 			for i in range(batch_size):
@@ -171,13 +159,11 @@ class AR_NAR(Base):
 					continue
 				resps_list[i] = torch.cat([resps_list[i], torch.Tensor([self.stop_token]).to(device=device, dtype=torch.int16) ])
-				targ_list[i] = torch.cat([targ_list[i], torch.Tensor([self.stop_token]).to(device=device, dtype=torch.int16) ])

 			inputs = self.inputs(
 				text_list=text_list,
 				proms_list=proms_list,
 				resps_list=resps_list,
-				targ_list=targ_list,

 				lang_list=lang_list,
 				tone_list=tone_list,
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index 2511735..434ffd6 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -100,11 +100,12 @@ class MultiEmbedding(nn.Module):
 		return x_list

 # Embedding that sums each RVQ-bin level within a given input acoustic prompt
-class AudioEmbedding(nn.Module):
+class AudioEmbedding_Old(nn.Module):
 	def __init__(
 		self,
 		l_tokens: int, # list of number of tokens (needed because AR resps includes stop token)
 		token_dim: int, # dimensionality of the embedding
+		mode: "old", # old | prom | resp
 		levels: int | None = None, # number of RVQ-bins (I don't remember the specifics)
 		sums: bool = True # whether to sum all previous layers of embeddings to factor in other RVQ bin levels (I do not know which way is better)
 	):
@@ -114,10 +115,12 @@ class AudioEmbedding(nn.Module):
 		# resp are split to where [0] is for the AR, and [1:] are reserved for NAR
 		self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for n_tokens in l_tokens])
 		# weight influencer for the influence for each level (desu this should be really useless because the weights in the embedding themselves should factor this)
-		self.weight = nn.ParameterList([nn.Parameter( torch.Tensor([1]) ) for i in range(levels)]) if levels is not None else None
+		self.weight = nn.ParameterList([nn.Parameter( torch.Tensor([1]) ) for i in range(levels)]) if levels is not None and mode == "old" else None
+		#
+		self.mode = mode
 		#
 		self.sums = sums
-
+
 	def forward(self, xi: Tensor, quant_levels: Tensor | None = None ) -> Tensor:
 		# prom
 		if quant_levels is None and xi.shape[-1] > 1:
@@ -139,6 +142,42 @@ class AudioEmbedding(nn.Module):
 		return x

+class AudioEmbedding(nn.Module):
+	def __init__(
+		self,
+		l_tokens: int, # list of number of tokens (needed because AR resps includes stop token)
+		token_dim: int, # dimensionality of the embedding
+		mode: str, # prom | resp
+		sums: bool = True # whether to sum all previous layers of embeddings to factor in other RVQ bin levels (I do not know which way is better)
+	):
+		super().__init__()
+		# array of embeddings
+		# proms are [0, prom_levels]
+		# resp are split to where [0] is for the AR, and [1:] are reserved for NAR
+		self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for n_tokens in l_tokens])
+		#
+		self.mode = mode
+		#
+		self.sums = sums
+
+	# maintaining compat is hard
+	def forward(self, xi: Tensor, quant_level: Tensor | None = None ) -> Tensor:
+		if quant_level is None:
+			quant_level = 0 if xi.dim() == 1 else xi.shape[-1] - 1
+
+		# embeddings for AR/NAR cannot be shared
+		offset = 0 if self.mode == "prom" or quant_level == 0 else 1
+
+		if xi.dim() == 1:
+			x = self.embeddings[quant_level]( xi )
+		elif self.sums and quant_level > 0:
+			x = sum( [ self.embeddings[k + offset]( xi[:, k] ) for k in range( quant_level ) ] )
+		else:
+			k = quant_level
+			x = self.embeddings[k + offset]( xi[:, k] )
+
+		return x
+
 class Base(nn.Module):
 	@property
 	def causal(self) -> bool:
@@ -258,17 +297,30 @@ class Base(nn.Module):
 				n_prom_tokens += (self.n_tasks - 1) # old models have the task tokens in the prom
 			self.proms_emb = MultiEmbedding(self.n_prom_levels, n_prom_tokens, d_model)
 			self.resps_emb = MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model, monolithic=self.monolithic)
-		else:
+		elif self.version < 5:
 			# [1024] * 8
-			self.proms_emb = AudioEmbedding(
+			self.proms_emb = AudioEmbedding_Old(
 				[n_prom_tokens] * self.n_prom_levels, d_model,
 				levels=self.n_prom_levels if self.version > 3 else None,
+				mode="prom" if self.version >= 5 else "old",
 				sums=self.config.audio_embedding_sums if self.config is not None else True,
 			)
 			# [1024 + STOP] + [1024] * 8
-			self.resps_emb = AudioEmbedding(
+			self.resps_emb = AudioEmbedding_Old(
 				[n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1), d_model,
 				levels=self.n_resp_levels if self.version > 3 else None,
+				mode="resp" if self.version >= 5 else "old",
+				sums=self.config.audio_embedding_sums if self.config is not None else True
+			)
+		else:
+			self.proms_emb = AudioEmbedding(
+				[n_prom_tokens] * self.n_prom_levels, d_model,
+				"prom",
+				sums=self.config.audio_embedding_sums if self.config is not None else True
+			)
+			self.resps_emb = AudioEmbedding(
+				[n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1), d_model,
+				"resp",
 				sums=self.config.audio_embedding_sums if self.config is not None else True
 			)
@@ -522,38 +574,6 @@ class Base(nn.Module):
 		x = inputs
 		m = mask.squeeze(-1).int()
 		aux_loss = None
-
-		"""
-		# Broken
-		if state is not None and (self.arch_type == "retnet" or self.arch_type == "retnet-hf"):
-			# prefill
-			if len(state) == 0:
-				prefill_size = x.shape[1]
-				# run the initial prompt to fill the KV cache
-				if self.arch_type == "retnet":
-					for n in range(prefill_size):
-						xi = x[:, n, :].unsqueeze(1)
-						self.model(xi, incremental_state=state, token_embeddings=xi, features_only=True)
-				elif self.arch_type == "retnet-hf":
-					state = None
-					for n in range(prefill_size):
-						xi = x[:, n, :].unsqueeze(1)
-
-						kwargs = dict(
-							attention_mask=m,
-							inputs_embeds=xi,
-							past_key_values=state,
-							use_cache=True,
-							forward_impl='recurrent',
-							# return_dict=True,
-						)
-
-						out = self.model(**kwargs)
-						state = out.past_key_values
-
-			# grab last token(s)
-			x = x[:, -1, :].unsqueeze(1)
-		"""

 		# HF transformer derived model
 		if self.arch_type in ["llama", "mistral", "mixtral"]:
@@ -564,7 +584,7 @@ class Base(nn.Module):
 				use_cache=True,
 				# return_dict=True,
 			)
-			if self.n_experts > 1 and targ_list is not None:
+			if self.n_experts > 1 and self.training:
 				kwargs["output_router_logits"] = True

 			t = self.model(**kwargs)
@@ -574,7 +594,7 @@ class Base(nn.Module):
 			if state is not None:
 				state = t[1]

-			if self.n_experts > 1 and targ_list is not None:
+			if self.n_experts > 1 and self.training:
 				router_logits = t[-1]
 				aux_loss = self.model.config.router_aux_loss_coef * load_balancing_loss_func( router_logits, self.model.config.num_local_experts, self.model.config.num_experts_per_tok )
 		elif self.arch_type == "transformer":
@@ -622,7 +642,6 @@ class Base(nn.Module):
 		text_list: list[Tensor],
 		proms_list: list[Tensor],
 		resps_list: list[Tensor],
-		targ_list: list[Tensor] | None = None,

 		lang_list: list[Tensor] | None = None,
 		tone_list: list[Tensor] | None = None,
@@ -646,8 +665,6 @@ class Base(nn.Module):
 				inputs[i].append( ( "prom", proms_list[i] ) )
 			if resps_list is not None:
 				inputs[i].append( ( "resp", resps_list[i] ) )
-			if targ_list is not None:
-				inputs[i].append( ( "targ", targ_list[i] ) )

 		return inputs
@@ -669,11 +686,11 @@ class Base(nn.Module):
 			elif name == "lang" and self.langs_emb is not None:
 				embedding = self.langs_emb( input )
 			elif name == "prom":
-				embedding = self.proms_emb( input )
+				embedding = self.proms_emb( input if input.dim() == 1 or quant_level == 0 else input[:, :quant_level] )
 			elif name == "tone" and self.tones_emb is not None:
 				embedding = self.tones_emb( input )
 			elif name == "resp":
-				embedding = self.resps_emb( input, quant_level )
+				embedding = self.resps_emb( input if input.dim() == 1 or quant_level == 0 else input[:, :quant_level], quant_level )
 			else:
 				continue
@@ -698,7 +715,9 @@ class Base(nn.Module):
 			for name, input in batch:
 				if name == "prom":
 					target.append( torch.full_like(input[..., 0], self.ignore_index) )
-				elif name in ["text", "quant_level", "lang", "tone", "targ"]:
+				elif name == "resp":
+					target.append( input if input.dim() == 1 else input[:, quant_level-1] )
+				elif name in ["text", "quant_level", "lang", "tone"]:
 					target.append( input )

 			target_list.append( _join( target, torch.tensor(self.ignore_index, device=target[-1].device) ) )
@@ -755,10 +774,7 @@ class Base(nn.Module):
 			for name, input in batch:
 				# do not use resp
 				if name == "resp":
-					continue
-				# rename to resp
-				if name == "targ":
-					name = "resp"
+					input = input if input.dim() == 1 else input[:, quant_level]
 				# select prom level
 				elif name == "prom" and quant_level is not None:
 					input = input[:, quant_level]
@@ -825,13 +841,15 @@ class Base(nn.Module):
 		x_list = self.inputs_to_embeddings( inputs, quant_levels )
 		x, m = list_to_tensor(x_list)

+		training = self.training # yes, there's a better way.
+		"""
 		training = False
 		for batch_index, batch in enumerate(inputs):
 			for name, input in batch:
 				if name == "targ":
 					training = True
-
+		"""

 		device = x.device
 		batch_size = len(x_list)
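
For reference, here is a minimal standalone sketch (not part of the patch) of the table-selection logic the new AudioEmbedding applies on the response path: the AR's level-0 table carries the stop token, so NAR lookups are shifted up by one table index, and the embeddings of every level below the target level are summed. ToyAudioEmbedding, the toy sizes, and the usage lines below are illustrative assumptions, and only the default sums=True path is modeled.

# minimal sketch, assuming torch and Python >= 3.10; names are illustrative, not the repo's
import torch
import torch.nn as nn

class ToyAudioEmbedding(nn.Module):
	def __init__(self, l_tokens: list[int], token_dim: int, mode: str, sums: bool = True):
		super().__init__()
		# one embedding table per RVQ level; for "resp" mode, table [0] belongs to the AR
		self.embeddings = nn.ModuleList([nn.Embedding(n, token_dim) for n in l_tokens])
		self.mode = mode # "prom" | "resp"
		self.sums = sums # only the default sums=True path is modeled below

	def forward(self, xi: torch.Tensor, quant_level: int | None = None) -> torch.Tensor:
		if quant_level is None:
			quant_level = 0 if xi.dim() == 1 else xi.shape[-1] - 1
		# AR (level 0) and NAR (levels 1+) response codes use separate tables,
		# so NAR lookups are shifted up by one table index
		offset = 0 if self.mode == "prom" or quant_level == 0 else 1
		if xi.dim() == 1:
			# AR path: a single sequence of level-0 codes (table [0] includes the stop token)
			return self.embeddings[quant_level](xi)
		# NAR path: sum the embeddings of levels 0 .. quant_level-1 (the conditioning levels)
		return sum(self.embeddings[k + offset](xi[:, k]) for k in range(quant_level))

# toy usage: 4 response levels, 1024 codes per level, plus a stop token for the AR's table
emb = ToyAudioEmbedding([1024 + 1] + [1024] * 3, token_dim=16, mode="resp")
ar_codes  = torch.randint(0, 1025, (75,))   # level-0 codes (AR)
nar_codes = torch.randint(0, 1024, (75, 3)) # levels 0..2, conditioning a NAR pass for level 3
print(emb(ar_codes).shape)                  # torch.Size([75, 16])
print(emb(nar_codes, quant_level=3).shape)  # torch.Size([75, 16])

Keeping the AR and NAR tables separate is also why resps_emb is built with [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1): only the AR's table needs room for the stop token.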