tweaks for the NAR-len model, maybe

This commit is contained in:
mrq 2024-08-03 08:40:39 -05:00
parent 97c5241bef
commit 66407e5bdb
3 changed files with 51 additions and 28 deletions

View File

@@ -183,7 +183,7 @@ def load_engines(training=True):
uses_stop_token = 1 if "len" not in model.capabilities and model.causal_size > 0 else 0
keys = [
("text_emb.weight", model.config.text_tokens ),
("rvq_l_emb.weight", model.config.resp_levels ),
("rvq_l_emb.weight", model.config.resp_levels + (1 if "len" in model.config.capabilities else 0) ),
("resps_emb.embeddings.0.weight", model.config.audio_tokens + uses_stop_token ),
("model.embed_tokens.weight", model.config.audio_tokens + uses_stop_token ),
("classifiers.proj.0.weight" if model.config.experimental.split_classifiers else 'classifier.weight', model.config.audio_tokens + uses_stop_token ),

View File

@@ -423,6 +423,11 @@ class Base(nn.Module):
# check if requested arch is unavailable
if self.arch_type in ERROR_ARCHES:
raise ERROR_ARCHES[self.arch_type]
audio_embedding_sums = self.config.experimental.audio_embedding_sums if self.config is not None else False
split_classifiers = self.config.experimental.split_classifiers if self.config is not None else False
tie_classifier_to_embedding = self.config.experimental.tie_classifier_to_embedding if self.config is not None else False
audio_embedding_mode = self.config.experimental.audio_embedding_mode if self.config is not None else ""
unified_position_ids = self.config.experimental.unified_position_ids if self.config is not None else True
if "len" not in self.capabilities:
# +1 to include the stop token
@@ -430,13 +435,7 @@ class Base(nn.Module):
l_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
else:
n_resp_tokens = n_audio_tokens
l_tokens = [n_resp_tokens] * self.n_resp_levels
audio_embedding_sums = self.config.experimental.audio_embedding_sums if self.config is not None else False
split_classifiers = self.config.experimental.split_classifiers if self.config is not None else False
tie_classifier_to_embedding = self.config.experimental.tie_classifier_to_embedding if self.config is not None else False
audio_embedding_mode = self.config.experimental.audio_embedding_mode if self.config is not None else ""
unified_position_ids = self.config.experimental.unified_position_ids if self.config is not None else True
l_tokens = [n_resp_tokens] * (self.n_resp_levels + (1 if split_classifiers else 0))
# there seems to be a problem with the NAR-only model when position IDs are not unified
if "len" in self.capabilities and not unified_position_ids:
@@ -494,7 +493,7 @@ class Base(nn.Module):
# this *might* help for AR and NAR tasks since we explicitly specify the current RVQ level for a sequence, rather than having it "encoded" in the embeddings
# this ***might*** let me also unify the proms_emb and resps_embedding
if self.version >= 5:
self.rvq_l_emb = Embedding(self.n_resp_levels, d_model)
self.rvq_l_emb = Embedding(self.n_resp_levels + (1 if "len" in self.capabilities else 0), d_model)
# experimental NAR-only mode
self.len_emb = Embedding(11, d_model) if "len" in self.capabilities else None
@@ -946,7 +945,7 @@ class Base(nn.Module):
if self.rvq_l_emb is not None:
# override to 0 (this does propagate: item assignment mutates the caller's `quant_levels` object in place rather than rebinding a local copy)
quant_levels[i] = 0
inputs[i].append( ( "quant_level", torch.Tensor([ 0 ]).to(device=device, dtype=torch.int16) ) )
inputs[i].append( ( "quant_level", torch.Tensor([ self.n_resp_levels ]).to(device=device, dtype=torch.int16) ) )
# insert input audio prompt
if proms_list is not None and proms_list[i] is not None:
inputs[i].append( ( "prom", proms_list[i] ) )
@@ -1112,6 +1111,7 @@ class Base(nn.Module):
quant_levels: int | list[int] | Tensor | None = None,
):
device = logits[0].device
classifier_quant_levels = quant_levels if self.classifier is not None else [ -1 if inputs[i][0][-1] == "len" else l for i, l in enumerate( quant_levels ) ]
# handles tasks where the prompt has task tokens injected in the middle
def prompt_input_to_token( input, quant_level ):
@@ -1148,16 +1148,22 @@ class Base(nn.Module):
batch_size = len(target_list)
# modify only for the AR so it can properly behave like a transformer
for i in range(batch_size):
if "len" in self.capabilities:
if task_list[i] != "len":
continue
else: # elif "nar" in self.capabilities: # for when I stop coping and drop the NAR entirely
if quant_levels is not None and quant_levels[i] > 0:
continue
quant_level = quant_levels[i]
task_name = task_list[i]
l = self.causal_size
logits[i] = logits[i][..., :-l, :] # shift the target so that token n...
target_list[i] = target_list[i][..., l:] # predicts token n + 1
causal = False
if "len" in self.capabilities:
causal = task_name == "len"
if quant_level >= self.n_resp_levels:
quant_level = 0
else:
causal = (quant_level == 0 and "ar" in self.capabilities) or ("nar" not in self.capabilities)
if causal:
l = self.causal_size
logits[i] = logits[i][..., :-l, :] # shift the target so that token n...
target_list[i] = target_list[i][..., l:] # predicts token n + 1
# see comments for the split-loss calc cross_entropy call
if False:
@@ -1167,7 +1173,7 @@ class Base(nn.Module):
# "nll" was in the original implementation and should actually just be called something else
nll = F.cross_entropy( inputs, target, ignore_index=self.ignore_index )
)
self.stats = self.metrics( inputs, targets, quant_levels ) if self.metrics is not None else dict(
self.stats = self.metrics( inputs, targets, classifier_quant_levels ) if self.metrics is not None else dict(
acc = self.accuracy_metric( inputs, target ),
# precision = self.precision_metric( inputs, target ),
)
@@ -1175,7 +1181,7 @@ class Base(nn.Module):
self.loss = dict(
nll = sum([ F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) for targets, inputs in zip( target_list, logits ) ]) / batch_size
)
self.stats = self.metrics( logits, target_list, quant_levels ) if self.metrics is not None else dict(
self.stats = self.metrics( logits, target_list, classifier_quant_levels ) if self.metrics is not None else dict(
acc = sum( [ self.accuracy_metric( inputs, targets ) for targets, inputs in zip( target_list, logits ) ] ) / batch_size
)
@@ -1199,6 +1205,8 @@ class Base(nn.Module):
quant_level = quant_levels[i]
it = 0
task_name = None
for name, input in batch:
# do not use resp
if name == "resp":
@@ -1209,6 +1217,7 @@ class Base(nn.Module):
input = torch.cat( [ prompt_input_to_token( input, quant_level ) for input in proms ] )
# meta-input, no corresponding token at the moment
elif name == "task":
task_name = input
continue
seq_len = input.shape[0]
@@ -1216,9 +1225,17 @@ class Base(nn.Module):
logit = logits[i][it:it+seq_len]
it += seq_len + 1 # +1 to incorporate the separator
causal = False
if "len" in self.capabilities:
causal = task_name == "len"
if quant_level >= self.n_resp_levels:
quant_level = 0
else:
causal = (quant_level == 0 and "ar" in self.capabilities) or ("nar" not in self.capabilities)
# for the AR, shift sequence so that it predicts the next token
# (the NAR predicts the next token in place, so it's not necessary to do any modifications for it)
if (quant_level == 0 or "nar" not in self.capabilities) and seq_len > 1:
if causal and seq_len > 1:
l = self.causal_size
logit = logit[..., :-l, :]
input = input[..., l:] # shift sequence to the right by one (or causal chunk size)
@@ -1235,6 +1252,7 @@ class Base(nn.Module):
for name, batch in info.items():
loss_factor = self.loss_factor(name)
if name not in ["text", "prom", "resp", "len"]:
continue
@@ -1253,7 +1271,7 @@ class Base(nn.Module):
else:
self.loss[name] = sum([ F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) * loss_factor for targets, inputs in zip( batch["targets"], batch["logits"] ) ]) / batch_size
if self.metrics is not None:
metrics = self.metrics( batch["logits"], batch["targets"], quant_levels )
metrics = self.metrics( batch["logits"], batch["targets"], classifier_quant_levels )
self.stats["acc"][name] = metrics["acc"]
else:
self.stats["acc"][name] = sum( [ self.accuracy_metric( inputs, targets ) for targets, inputs in zip( batch["targets"], batch["logits"] ) ] ) / batch_size
@@ -1311,7 +1329,8 @@ class Base(nn.Module):
)
if self.classifiers is not None:
x = self.classifiers(x, levels = quant_levels) * m
classifier_quant_levels = quant_levels if self.classifier is not None else [ -1 if inputs[i][0][-1] == "len" else l for i, l in enumerate( quant_levels ) ]
x = self.classifiers(x, levels = classifier_quant_levels) * m
# Remove padding
logits = [ hi[:li] for hi, li in zip(x, map(len, x_list)) ]
@@ -1363,6 +1382,13 @@ class Base(nn.Module):
devices = [ logit.device for logit in logits ]
logits = [ logit.to(device="cpu", dtype=logit.dtype if logit.dtype != torch.float16 else torch.float32) for logit in logits ]
# (NAR) disable stop token
if quant_levels is not None and "ar" in self.capabilities:
logits = [ ban_tokens(logit, tokens=[self.stop_token]) for logit, l in zip( logits, map(len, resps_list) ) ]
# (AR-len) disable extraneous tokens
if quant_levels is None and "len" in self.capabilities:
logits = [ ban_tokens(logit, tokens=[*range(11, logit.shape[-1])]) for logit, l in zip( logits, map(len, resps_list) ) ]
# argmax instead
if temperature <= 0.0:
@@ -1375,9 +1401,6 @@ class Base(nn.Module):
# (AR) perform length penalizing
if quant_levels is None and self.causal:
logits = [ length_penalize(logit, length=l + 1, factor=length_penalty, token=self.stop_token) for logit, l in zip( logits, map(len, resps_list) ) ]
# (NAR) disable stop token
elif "ar" in self.capabilities:
logits = [ ban_tokens(logit, tokens=[self.stop_token]) for logit, l in zip( logits, map(len, resps_list) ) ]
# perform top_k/top_p filtering of our logits
if top_k > 0 or top_p < 1.0:

View File

@@ -295,7 +295,7 @@ class NAR(Base):
# sanitize
for i, token in enumerate(r):
if token > 10:
r[i] = 0
r[i][0] = stop_token
# append tokens
for i, ri in enumerate(r):