From 66407e5bdb8a9a50382bd43e2d92bacdfd1b73e8 Mon Sep 17 00:00:00 2001 From: mrq Date: Sat, 3 Aug 2024 08:40:39 -0500 Subject: [PATCH] tweaks for the NAR-len model, maybe --- vall_e/engines/__init__.py | 2 +- vall_e/models/base.py | 75 +++++++++++++++++++++++++------------- vall_e/models/nar.py | 2 +- 3 files changed, 51 insertions(+), 28 deletions(-) diff --git a/vall_e/engines/__init__.py b/vall_e/engines/__init__.py index 6d0f083..2e3b38a 100755 --- a/vall_e/engines/__init__.py +++ b/vall_e/engines/__init__.py @@ -183,7 +183,7 @@ def load_engines(training=True): uses_stop_token = 1 if "len" not in model.capabilities and model.causal_size > 0 else 0 keys = [ ("text_emb.weight", model.config.text_tokens ), - ("rvq_l_emb.weight", model.config.resp_levels ), + ("rvq_l_emb.weight", model.config.resp_levels + (1 if "len" in model.config.capabilities else 0) ), ("resps_emb.embeddings.0.weight", model.config.audio_tokens + uses_stop_token ), ("model.embed_tokens.weight", model.config.audio_tokens + uses_stop_token ), ("classifiers.proj.0.weight" if model.config.experimental.split_classifiers else 'classifier.weight', model.config.audio_tokens + uses_stop_token ), diff --git a/vall_e/models/base.py b/vall_e/models/base.py index 638bcc3..f345289 100755 --- a/vall_e/models/base.py +++ b/vall_e/models/base.py @@ -423,6 +423,11 @@ class Base(nn.Module): # check if requested arch is unavailable if self.arch_type in ERROR_ARCHES: raise ERROR_ARCHES[self.arch_type] + audio_embedding_sums = self.config.experimental.audio_embedding_sums if self.config is not None else False + split_classifiers = self.config.experimental.split_classifiers if self.config is not None else False + tie_classifier_to_embedding = self.config.experimental.tie_classifier_to_embedding if self.config is not None else False + audio_embedding_mode = self.config.experimental.audio_embedding_mode if self.config is not None else "" + unified_position_ids = self.config.experimental.unified_position_ids if self.config is not None else True if "len" not in self.capabilities: # +1 to include the stop token @@ -430,13 +435,7 @@ class Base(nn.Module): l_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1) else: n_resp_tokens = n_audio_tokens - l_tokens = [n_resp_tokens] * self.n_resp_levels - - audio_embedding_sums = self.config.experimental.audio_embedding_sums if self.config is not None else False - split_classifiers = self.config.experimental.split_classifiers if self.config is not None else False - tie_classifier_to_embedding = self.config.experimental.tie_classifier_to_embedding if self.config is not None else False - audio_embedding_mode = self.config.experimental.audio_embedding_mode if self.config is not None else "" - unified_position_ids = self.config.experimental.unified_position_ids if self.config is not None else True + l_tokens = [n_resp_tokens] * (self.n_resp_levels + (1 if split_classifiers else 0)) # there seems to be a problem with the NAR-only model with non-unified position IDs............. if "len" in self.capabilities and not unified_position_ids: @@ -494,7 +493,7 @@ class Base(nn.Module): # this *might* help for AR and NAR tasks since we explicitly specify the current RVQ level for a sequence, rather than having it "encoded" in the embeddings # this ***might*** let me also unify the proms_emb and resps_embedding if self.version >= 5: - self.rvq_l_emb = Embedding(self.n_resp_levels, d_model) + self.rvq_l_emb = Embedding(self.n_resp_levels + (1 if "len" in self.capabilities else 0), d_model) # experimental NAR-only mode self.len_emb = Embedding(11, d_model) if "len" in self.capabilities else None @@ -946,7 +945,7 @@ class Base(nn.Module): if self.rvq_l_emb is not None: # override to 0 (I don't know if this change propagates, I'm not familiar with when python passes by (copied) value or reference) quant_levels[i] = 0 - inputs[i].append( ( "quant_level", torch.Tensor([ 0 ]).to(device=device, dtype=torch.int16) ) ) + inputs[i].append( ( "quant_level", torch.Tensor([ self.n_resp_levels ]).to(device=device, dtype=torch.int16) ) ) # insert input audio prompt if proms_list is not None and proms_list[i] is not None: inputs[i].append( ( "prom", proms_list[i] ) ) @@ -1112,6 +1111,7 @@ class Base(nn.Module): quant_levels: int | list[int] | Tensor | None = None, ): device = logits[0].device + classifier_quant_levels = quant_levels if self.classifier is not None else [ -1 if inputs[i][0][-1] == "len" else l for i, l in enumerate( quant_levels ) ] # handles tasks where the prompt has task tokens injected in the middle def prompt_input_to_token( input, quant_level ): @@ -1148,16 +1148,22 @@ class Base(nn.Module): batch_size = len(target_list) # modify only for the AR so it can properly behave like a transformer for i in range(batch_size): - if "len" in self.capabilities: - if task_list[i] != "len": - continue - else: # elif "nar" in self.capabilities: # for when I stop coping and drop the NAR entirely - if quant_levels is not None and quant_levels[i] > 0: - continue + quant_level = quant_levels[i] + task_name = task_list[i] - l = self.causal_size - logits[i] = logits[i][..., :-l, :] # shift the target so that token n... - target_list[i] = target_list[i][..., l:] # predicts token n + 1 + causal = False + + if "len" in self.capabilities: + causal = task_name == "len" + if quant_level >= self.n_resp_levels: + quant_level = 0 + else: + causal = (quant_level == 0 and "ar" in self.capabilities) or ("nar" not in self.capabilities) + + if causal: + l = self.causal_size + logits[i] = logits[i][..., :-l, :] # shift the target so that token n... + target_list[i] = target_list[i][..., l:] # predicts token n + 1 # see comments for the split-loss calc cross_entropy call if False: @@ -1167,7 +1173,7 @@ class Base(nn.Module): # "nll" was in the original implementation and should actually just be called something else nll = F.cross_entropy( inputs, target, ignore_index=self.ignore_index ) ) - self.stats = self.metrics( inputs, targets, quant_levels ) if self.metrics is not None else dict( + self.stats = self.metrics( inputs, targets, classifier_quant_levels ) if self.metrics is not None else dict( acc = self.accuracy_metric( inputs, target ), # precision = self.precision_metric( inputs, target ), ) @@ -1175,7 +1181,7 @@ class Base(nn.Module): self.loss = dict( nll = sum([ F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) for targets, inputs in zip( target_list, logits ) ]) / batch_size ) - self.stats = self.metrics( logits, target_list, quant_levels ) if self.metrics is not None else dict( + self.stats = self.metrics( logits, target_list, classifier_quant_levels ) if self.metrics is not None else dict( acc = sum( [ self.accuracy_metric( inputs, targets ) for targets, inputs in zip( target_list, logits ) ] ) / batch_size ) @@ -1199,6 +1205,8 @@ class Base(nn.Module): quant_level = quant_levels[i] it = 0 + + task_name = None for name, input in batch: # do not use resp if name == "resp": @@ -1209,6 +1217,7 @@ class Base(nn.Module): input = torch.cat( [ prompt_input_to_token( input, quant_level ) for input in proms ] ) # meta-input, no corresponding token at the moment elif name == "task": + task_name = input continue seq_len = input.shape[0] @@ -1216,9 +1225,17 @@ class Base(nn.Module): logit = logits[i][it:it+seq_len] it += seq_len + 1 # +1 to incorporate the separator + causal = False + if "len" in self.capabilities: + causal = task_name == "len" + if quant_level >= self.n_resp_levels: + quant_level = 0 + else: + causal = (quant_level == 0 and "ar" in self.capabilities) or ("nar" not in self.capabilities) + # for the AR, shift sequence so that it predicts the next token # (the NAR predicts the next token in place, so it's not necessary to do any modifications for it) - if (quant_level == 0 or "nar" not in self.capabilities) and seq_len > 1: + if causal and seq_len > 1: l = self.causal_size logit = logit[..., :-l, :] input = input[..., l:] # shift sequence to the right by one (or causal chunk size) @@ -1235,6 +1252,7 @@ class Base(nn.Module): for name, batch in info.items(): loss_factor = self.loss_factor(name) + if name not in ["text", "prom", "resp", "len"]: continue @@ -1253,7 +1271,7 @@ class Base(nn.Module): else: self.loss[name] = sum([ F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) * loss_factor for targets, inputs in zip( batch["targets"], batch["logits"] ) ]) / batch_size if self.metrics is not None: - metrics = self.metrics( batch["logits"], batch["targets"], quant_levels ) + metrics = self.metrics( batch["logits"], batch["targets"], classifier_quant_levels ) self.stats["acc"][name] = metrics["acc"] else: self.stats["acc"][name] = sum( [ self.accuracy_metric( inputs, targets ) for targets, inputs in zip( batch["targets"], batch["logits"] ) ] ) / batch_size @@ -1311,7 +1329,8 @@ class Base(nn.Module): ) if self.classifiers is not None: - x = self.classifiers(x, levels = quant_levels) * m + classifier_quant_levels = quant_levels if self.classifier is not None else [ -1 if inputs[i][0][-1] == "len" else l for i, l in enumerate( quant_levels ) ] + x = self.classifiers(x, levels = classifier_quant_levels) * m # Remove padding logits = [ hi[:li] for hi, li in zip(x, map(len, x_list)) ] @@ -1363,6 +1382,13 @@ class Base(nn.Module): devices = [ logit.device for logit in logits ] logits = [ logit.to(device="cpu", dtype=logit.dtype if logit.dtype != torch.float16 else torch.float32) for logit in logits ] + + # (NAR) disable stop token + if quant_levels is not None and "ar" in self.capabilities: + logits = [ ban_tokens(logit, tokens=[self.stop_token]) for logit, l in zip( logits, map(len, resps_list) ) ] + # (AR-len) disable extraneous tokens + if quant_levels is None and "len" in self.capabilities: + logits = [ ban_tokens(logit, tokens=[*range(11, logit.shape[-1])]) for logit, l in zip( logits, map(len, resps_list) ) ] # argmax instead if temperature <= 0.0: @@ -1375,9 +1401,6 @@ class Base(nn.Module): # (AR) perform length penalizing if quant_levels is None and self.causal: logits = [ length_penalize(logit, length=l + 1, factor=length_penalty, token=self.stop_token) for logit, l in zip( logits, map(len, resps_list) ) ] - # (NAR) disable stop token - elif "ar" in self.capabilities: - logits = [ ban_tokens(logit, tokens=[self.stop_token]) for logit, l in zip( logits, map(len, resps_list) ) ] # perform top_k/top_p filtering of our logits if top_k > 0 or top_p < 1.0: diff --git a/vall_e/models/nar.py b/vall_e/models/nar.py index 2f08439..c5af2d2 100644 --- a/vall_e/models/nar.py +++ b/vall_e/models/nar.py @@ -295,7 +295,7 @@ class NAR(Base): # sanitize for i, token in enumerate(r): if token > 10: - r[i] = 0 + r[i][0] = stop_token # append tokens for i, ri in enumerate(r):