tweaks for the NAR-len model, maybe
This commit is contained in:
parent
97c5241bef
commit
66407e5bdb
|
@ -183,7 +183,7 @@ def load_engines(training=True):
|
|||
uses_stop_token = 1 if "len" not in model.capabilities and model.causal_size > 0 else 0
|
||||
keys = [
|
||||
("text_emb.weight", model.config.text_tokens ),
|
||||
("rvq_l_emb.weight", model.config.resp_levels ),
|
||||
("rvq_l_emb.weight", model.config.resp_levels + (1 if "len" in model.config.capabilities else 0) ),
|
||||
("resps_emb.embeddings.0.weight", model.config.audio_tokens + uses_stop_token ),
|
||||
("model.embed_tokens.weight", model.config.audio_tokens + uses_stop_token ),
|
||||
("classifiers.proj.0.weight" if model.config.experimental.split_classifiers else 'classifier.weight', model.config.audio_tokens + uses_stop_token ),
|
||||
|
|
|
@ -423,6 +423,11 @@ class Base(nn.Module):
|
|||
# check if requested arch is unavailable
|
||||
if self.arch_type in ERROR_ARCHES:
|
||||
raise ERROR_ARCHES[self.arch_type]
|
||||
audio_embedding_sums = self.config.experimental.audio_embedding_sums if self.config is not None else False
|
||||
split_classifiers = self.config.experimental.split_classifiers if self.config is not None else False
|
||||
tie_classifier_to_embedding = self.config.experimental.tie_classifier_to_embedding if self.config is not None else False
|
||||
audio_embedding_mode = self.config.experimental.audio_embedding_mode if self.config is not None else ""
|
||||
unified_position_ids = self.config.experimental.unified_position_ids if self.config is not None else True
|
||||
|
||||
if "len" not in self.capabilities:
|
||||
# +1 to include the stop token
|
||||
|
@ -430,13 +435,7 @@ class Base(nn.Module):
|
|||
l_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
|
||||
else:
|
||||
n_resp_tokens = n_audio_tokens
|
||||
l_tokens = [n_resp_tokens] * self.n_resp_levels
|
||||
|
||||
audio_embedding_sums = self.config.experimental.audio_embedding_sums if self.config is not None else False
|
||||
split_classifiers = self.config.experimental.split_classifiers if self.config is not None else False
|
||||
tie_classifier_to_embedding = self.config.experimental.tie_classifier_to_embedding if self.config is not None else False
|
||||
audio_embedding_mode = self.config.experimental.audio_embedding_mode if self.config is not None else ""
|
||||
unified_position_ids = self.config.experimental.unified_position_ids if self.config is not None else True
|
||||
l_tokens = [n_resp_tokens] * (self.n_resp_levels + (1 if split_classifiers else 0))
|
||||
|
||||
# there seems to be a problem with the NAR-only model when using non-unified position IDs...
|
||||
if "len" in self.capabilities and not unified_position_ids:
|
||||
|
@ -494,7 +493,7 @@ class Base(nn.Module):
|
|||
# this *might* help for AR and NAR tasks since we explicitly specify the current RVQ level for a sequence, rather than having it "encoded" in the embeddings
|
||||
# this ***might*** let me also unify the proms_emb and resps_embedding
|
||||
if self.version >= 5:
|
||||
self.rvq_l_emb = Embedding(self.n_resp_levels, d_model)
|
||||
self.rvq_l_emb = Embedding(self.n_resp_levels + (1 if "len" in self.capabilities else 0), d_model)
|
||||
|
||||
# experimental NAR-only mode
|
||||
self.len_emb = Embedding(11, d_model) if "len" in self.capabilities else None
|
||||
|
@ -946,7 +945,7 @@ class Base(nn.Module):
|
|||
if self.rvq_l_emb is not None:
|
||||
# override to 0 (item assignment mutates quant_levels in place; Python passes object references rather than copies, so if this is the caller's own list/tensor the caller sees the change too)
|
||||
quant_levels[i] = 0
|
||||
inputs[i].append( ( "quant_level", torch.Tensor([ 0 ]).to(device=device, dtype=torch.int16) ) )
|
||||
inputs[i].append( ( "quant_level", torch.Tensor([ self.n_resp_levels ]).to(device=device, dtype=torch.int16) ) )
|
||||
# insert input audio prompt
|
||||
if proms_list is not None and proms_list[i] is not None:
|
||||
inputs[i].append( ( "prom", proms_list[i] ) )
|
||||
|
@ -1112,6 +1111,7 @@ class Base(nn.Module):
|
|||
quant_levels: int | list[int] | Tensor | None = None,
|
||||
):
|
||||
device = logits[0].device
|
||||
classifier_quant_levels = quant_levels if self.classifier is not None else [ -1 if inputs[i][0][-1] == "len" else l for i, l in enumerate( quant_levels ) ]
|
||||
|
||||
# handles tasks where the prompt has task tokens injected in the middle
|
||||
def prompt_input_to_token( input, quant_level ):
|
||||
|
@ -1148,16 +1148,22 @@ class Base(nn.Module):
|
|||
batch_size = len(target_list)
|
||||
# modify only for the AR so it can properly behave like a transformer
|
||||
for i in range(batch_size):
|
||||
if "len" in self.capabilities:
|
||||
if task_list[i] != "len":
|
||||
continue
|
||||
else: # elif "nar" in self.capabilities: # for when I stop coping and drop the NAR entirely
|
||||
if quant_levels is not None and quant_levels[i] > 0:
|
||||
continue
|
||||
quant_level = quant_levels[i]
|
||||
task_name = task_list[i]
|
||||
|
||||
l = self.causal_size
|
||||
logits[i] = logits[i][..., :-l, :] # shift the target so that token n...
|
||||
target_list[i] = target_list[i][..., l:] # predicts token n + 1
|
||||
causal = False
|
||||
|
||||
if "len" in self.capabilities:
|
||||
causal = task_name == "len"
|
||||
if quant_level >= self.n_resp_levels:
|
||||
quant_level = 0
|
||||
else:
|
||||
causal = (quant_level == 0 and "ar" in self.capabilities) or ("nar" not in self.capabilities)
|
||||
|
||||
if causal:
|
||||
l = self.causal_size
|
||||
logits[i] = logits[i][..., :-l, :] # shift the target so that token n...
|
||||
target_list[i] = target_list[i][..., l:] # predicts token n + 1
|
||||
|
||||
# see comments for the split-loss calc cross_entropy call
|
||||
if False:
|
||||
|
@ -1167,7 +1173,7 @@ class Base(nn.Module):
|
|||
# "nll" was in the original implementation and should actually just be called something else
|
||||
nll = F.cross_entropy( inputs, target, ignore_index=self.ignore_index )
|
||||
)
|
||||
self.stats = self.metrics( inputs, targets, quant_levels ) if self.metrics is not None else dict(
|
||||
self.stats = self.metrics( inputs, targets, classifier_quant_levels ) if self.metrics is not None else dict(
|
||||
acc = self.accuracy_metric( inputs, target ),
|
||||
# precision = self.precision_metric( inputs, target ),
|
||||
)
|
||||
|
@ -1175,7 +1181,7 @@ class Base(nn.Module):
|
|||
self.loss = dict(
|
||||
nll = sum([ F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) for targets, inputs in zip( target_list, logits ) ]) / batch_size
|
||||
)
|
||||
self.stats = self.metrics( logits, target_list, quant_levels ) if self.metrics is not None else dict(
|
||||
self.stats = self.metrics( logits, target_list, classifier_quant_levels ) if self.metrics is not None else dict(
|
||||
acc = sum( [ self.accuracy_metric( inputs, targets ) for targets, inputs in zip( target_list, logits ) ] ) / batch_size
|
||||
)
|
||||
|
||||
|
@ -1199,6 +1205,8 @@ class Base(nn.Module):
|
|||
quant_level = quant_levels[i]
|
||||
|
||||
it = 0
|
||||
|
||||
task_name = None
|
||||
for name, input in batch:
|
||||
# do not use resp
|
||||
if name == "resp":
|
||||
|
@ -1209,6 +1217,7 @@ class Base(nn.Module):
|
|||
input = torch.cat( [ prompt_input_to_token( input, quant_level ) for input in proms ] )
|
||||
# meta-input, no corresponding token at the moment
|
||||
elif name == "task":
|
||||
task_name = input
|
||||
continue
|
||||
|
||||
seq_len = input.shape[0]
|
||||
|
@ -1216,9 +1225,17 @@ class Base(nn.Module):
|
|||
logit = logits[i][it:it+seq_len]
|
||||
it += seq_len + 1 # +1 to incorporate the separator
|
||||
|
||||
causal = False
|
||||
if "len" in self.capabilities:
|
||||
causal = task_name == "len"
|
||||
if quant_level >= self.n_resp_levels:
|
||||
quant_level = 0
|
||||
else:
|
||||
causal = (quant_level == 0 and "ar" in self.capabilities) or ("nar" not in self.capabilities)
|
||||
|
||||
# for the AR, shift sequence so that it predicts the next token
|
||||
# (the NAR predicts the next token in place, so it's not necessary to do any modifications for it)
|
||||
if (quant_level == 0 or "nar" not in self.capabilities) and seq_len > 1:
|
||||
if causal and seq_len > 1:
|
||||
l = self.causal_size
|
||||
logit = logit[..., :-l, :]
|
||||
input = input[..., l:] # shift sequence to the right by one (or causal chunk size)
|
||||
|
@ -1235,6 +1252,7 @@ class Base(nn.Module):
|
|||
|
||||
for name, batch in info.items():
|
||||
loss_factor = self.loss_factor(name)
|
||||
|
||||
if name not in ["text", "prom", "resp", "len"]:
|
||||
continue
|
||||
|
||||
|
@ -1253,7 +1271,7 @@ class Base(nn.Module):
|
|||
else:
|
||||
self.loss[name] = sum([ F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) * loss_factor for targets, inputs in zip( batch["targets"], batch["logits"] ) ]) / batch_size
|
||||
if self.metrics is not None:
|
||||
metrics = self.metrics( batch["logits"], batch["targets"], quant_levels )
|
||||
metrics = self.metrics( batch["logits"], batch["targets"], classifier_quant_levels )
|
||||
self.stats["acc"][name] = metrics["acc"]
|
||||
else:
|
||||
self.stats["acc"][name] = sum( [ self.accuracy_metric( inputs, targets ) for targets, inputs in zip( batch["targets"], batch["logits"] ) ] ) / batch_size
|
||||
|
@ -1311,7 +1329,8 @@ class Base(nn.Module):
|
|||
)
|
||||
|
||||
if self.classifiers is not None:
|
||||
x = self.classifiers(x, levels = quant_levels) * m
|
||||
classifier_quant_levels = quant_levels if self.classifier is not None else [ -1 if inputs[i][0][-1] == "len" else l for i, l in enumerate( quant_levels ) ]
|
||||
x = self.classifiers(x, levels = classifier_quant_levels) * m
|
||||
|
||||
# Remove padding
|
||||
logits = [ hi[:li] for hi, li in zip(x, map(len, x_list)) ]
|
||||
|
@ -1363,6 +1382,13 @@ class Base(nn.Module):
|
|||
|
||||
devices = [ logit.device for logit in logits ]
|
||||
logits = [ logit.to(device="cpu", dtype=logit.dtype if logit.dtype != torch.float16 else torch.float32) for logit in logits ]
|
||||
|
||||
# (NAR) disable stop token
|
||||
if quant_levels is not None and "ar" in self.capabilities:
|
||||
logits = [ ban_tokens(logit, tokens=[self.stop_token]) for logit, l in zip( logits, map(len, resps_list) ) ]
|
||||
# (AR-len) disable extraneous tokens
|
||||
if quant_levels is None and "len" in self.capabilities:
|
||||
logits = [ ban_tokens(logit, tokens=[*range(11, logit.shape[-1])]) for logit, l in zip( logits, map(len, resps_list) ) ]
|
||||
|
||||
# argmax instead
|
||||
if temperature <= 0.0:
|
||||
|
@ -1375,9 +1401,6 @@ class Base(nn.Module):
|
|||
# (AR) perform length penalizing
|
||||
if quant_levels is None and self.causal:
|
||||
logits = [ length_penalize(logit, length=l + 1, factor=length_penalty, token=self.stop_token) for logit, l in zip( logits, map(len, resps_list) ) ]
|
||||
# (NAR) disable stop token
|
||||
elif "ar" in self.capabilities:
|
||||
logits = [ ban_tokens(logit, tokens=[self.stop_token]) for logit, l in zip( logits, map(len, resps_list) ) ]
|
||||
|
||||
# perform top_k/top_p filtering of our logits
|
||||
if top_k > 0 or top_p < 1.0:
|
||||
|
|
|
@ -295,7 +295,7 @@ class NAR(Base):
|
|||
# sanitize
|
||||
for i, token in enumerate(r):
|
||||
if token > 10:
|
||||
r[i] = 0
|
||||
r[i][0] = stop_token
|
||||
|
||||
# append tokens
|
||||
for i, ri in enumerate(r):
|
||||
|
|
Loading…
Reference in New Issue
Block a user